Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fix in VoxelGrid code #95

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions modules/interpolate3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@

# https://gist.github.com/Kulbear/af6499e83382df88c2a2c42fb3143652
import torch
import numpy as np

def gather_nd_torch(params, indices, batch_dim=1):
    """ A PyTorch porting of tensorflow.gather_nd
    This implementation can handle leading batch dimensions in params, see below for detailed explanation.

    The majority of this implementation is from Michael Jungo @ https://stackoverflow.com/a/61810047/6670143
    I just ported it compatible to leading batch dimension.

    Args:
      params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c].
      indices: a tensor of dimension [b1, ..., bn, x, m]
      batch_dim: indicate how many batch dimension you have, in the above example, batch_dim = n.
        ``batch_dim=0`` (no leading batch dimensions) is supported.

    Returns:
      gathered: a tensor of dimension [b1, ..., bn, x, c].

    Example:
    >>> batch_size = 5
    >>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
    >>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
    >>> gathered = gather_nd_torch(inputs, pos, batch_dim=3)
    >>> gathered.shape
    torch.Size([5, 5, 5, 12, 32])

    >>> inputs_tf = tf.convert_to_tensor(inputs.numpy())
    >>> pos_tf = tf.convert_to_tensor(pos.numpy())
    >>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
    >>> gathered_tf.shape
    TensorShape([5, 5, 5, 12, 32])

    >>> gathered_tf = torch.from_numpy(gathered_tf.numpy())
    >>> torch.equal(gathered_tf, gathered)
    True
    """
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    # np.prod of an empty shape is 1, so batch_dim=0 collapses to a single
    # pseudo-batch instead of failing on an empty cumprod.
    batch_size = int(np.prod(batch_dims, dtype=np.int64))  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m

    # reshape leading batch dims to a single batch dim
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)

    # build gather indices
    # gather for each of the data point in this "batch"
    # (the [batch_size, 1] enumeration broadcasts against the
    # [batch_size, x] per-axis index tensors)
    batch_enumeration = torch.arange(batch_size, device=params.device).unsqueeze(1)
    gather_dims = [batch_enumeration]
    gather_dims.extend(indices[:, :, i] for i in range(len(grid_dims)))
    # Index with a tuple: indexing with a plain list of tensors is
    # deprecated and is no longer treated as multi-dimensional indexing
    # by newer PyTorch releases.
    gathered = params[tuple(gather_dims)]

    # reshape back to the shape with leading batch dims
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered


def interpolate(grid_3d, sampling_points):
    """Trilinearly interpolate values stored on a regular 3D grid.

    PyTorch port of the TensorFlow Graphics trilinear interpolation, see
    https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/math/interpolation/trilinear.py

    Args:
      grid_3d: tensor of shape `[A1, ..., An, H, W, D, C]`; H, W, D are the
        grid height, width and depth and C is the channel count.
      sampling_points: tensor of shape `[A1, ..., An, M, 3]` holding M query
        points in grid-index coordinates. Corner indices that fall outside
        the grid are clamped to the grid border.

    Returns:
      A tensor of shape `[A1, ..., An, M, C]`.
    """
    spatial_shape = grid_3d.size()[-4:-1]  # [H, W, D]
    points_shape = sampling_points.size()
    leading_dims = points_shape[:-2]  # [A1, ..., An]
    n_points = points_shape[-2]  # M

    # Lower / upper integer corners of the cell enclosing each point.
    lower = torch.floor(sampling_points)
    upper = lower + 1
    lo_x, lo_y, lo_z = torch.unbind(lower.type(torch.int32), dim=-1)
    hi_x, hi_y, hi_z = torch.unbind(upper.type(torch.int32), dim=-1)

    # The 8 cell corners are laid out back-to-back along the point axis:
    # x alternates fastest, then y, then z.
    corner_x = torch.concat([lo_x, hi_x] * 4, dim=-1)
    corner_y = torch.concat([lo_y, lo_y, hi_y, hi_y] * 2, dim=-1)
    corner_z = torch.concat([lo_z] * 4 + [hi_z] * 4, dim=-1)
    corners = torch.stack([corner_x, corner_y, corner_z], dim=-1)

    # Clamp every corner index into the valid range [0, size - 1] per axis.
    upper_bound = (torch.tensor(list(spatial_shape)) - 1).to(device=sampling_points.device)
    lower_bound = torch.zeros_like(upper_bound).to(device=sampling_points.device)
    corners = torch.clamp(corners, min=lower_bound, max=upper_bound)

    corner_values = gather_nd_torch(
        params=grid_3d, indices=corners.long(), batch_dim=len(leading_dims))

    # Fractional offsets inside the cell; each corner is weighted by the
    # distance to the opposite corner along every axis.
    from_lower = sampling_points - lower
    to_upper = upper - sampling_points
    dx_lo, dy_lo, dz_lo = torch.unbind(from_lower, dim=-1)
    dx_hi, dy_hi, dz_hi = torch.unbind(to_upper, dim=-1)
    weights_x = torch.concat([dx_hi, dx_lo] * 4, dim=-1)
    weights_y = torch.concat([dy_hi, dy_hi, dy_lo, dy_lo] * 2, dim=-1)
    weights_z = torch.concat([dz_hi] * 4 + [dz_lo] * 4, dim=-1)

    corner_weights = (weights_x * weights_y * weights_z).unsqueeze(-1)
    weighted_values = corner_weights * corner_values

    # Split the 8 corner groups apart again and add them up per point.
    per_corner = torch.split(weighted_values, [n_points] * 8, dim=-2)
    return sum(per_corner)
51 changes: 18 additions & 33 deletions modules/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .utils import morton3D, morton3D_invert, packbits
from .volume_train import VolumeRenderer
from .sh_utils import eval_sh
from .interpolate3d import interpolate


class TruncExp(torch.autograd.Function):
Expand Down Expand Up @@ -439,7 +440,8 @@ def initialize_grid(self):

Params:
grid_normalized_coords: (sx * sy * sz, 3), normalized coordinates of the grids
grid_fields: (sx, sy, sz, sh_dim + 1), data fields(sh and density) of the grids
sh_fields: (sx, sy, sz, sh_dim), data fields(sh) of the grids
density_fields: (sx, sy, sz, 1), data fields(density) of the grids
"""
if isinstance(self.grid_size, float) or isinstance(self.grid_size, int):
grid_res = [self.grid_size] * 3
Expand Down Expand Up @@ -484,7 +486,7 @@ def initialize_grid(self):
requires_grad=True,
)

self.grid_fields = torch.cat((self.sh_fields, self.density_fields), dim=3)
self.grid_shape = (grids.shape[0], grids.shape[1], grids.shape[2])

def out_of_grid(self, idx):
"""
Expand All @@ -499,7 +501,7 @@ def out_of_grid(self, idx):
x_idx, y_idx, z_idx = idx.unbind(-1)

# find which points are outside the grid
sx, sy, sz, _ = self.grid_fields.shape
sx, sy, sz = self.grid_shape
x_idx_valid = (x_idx < sx) & (x_idx >= 0)
y_idx_valid = (y_idx < sy) & (y_idx >= 0)
z_idx_valid = (z_idx < sz) & (z_idx >= 0)
Expand All @@ -511,7 +513,7 @@ def fix_out_of_grid(self, idx):
x_idx, y_idx, z_idx = idx.unbind(-1)

# find which points are outside the grid
sx, sy, sz, _ = self.grid_fields.shape
sx, sy, sz= self.grid_shape
x_idx %= sx
y_idx %= sy
z_idx %= sz
Expand All @@ -521,17 +523,6 @@ def fix_out_of_grid(self, idx):
def normalize_samples(self, pts):
    """Shift `pts` by the per-column minimum of the grid coordinates and scale by the grid radius.

    Reads `self.grid_normalized_coords` (minimum taken along dim 0) and
    `self.grid_radius`.
    """
    grid_min = self.grid_normalized_coords.min(0)[0]
    shifted = pts - grid_min
    return shifted / self.grid_radius

def trilinear_interpolation(self, bundles, weight_a, weight_b):
    """Blend eight value bundles with per-axis weights.

    Args:
      bundles: indexable holding eight value tensors (entries 0..7).
        NOTE(review): the corner ordering is whatever the caller supplies - confirm.
      weight_a: per-sample weights with three columns; the last column
        blends the even/odd bundle pairs, the middle column the resulting
        pairs, and the first column the final pair.
      weight_b: complementary weights, indexed the same way as `weight_a`.

    Returns:
      The blended values.
    """
    a_first, a_mid, a_last = weight_a[:, :1], weight_a[:, 1:2], weight_a[:, 2:]
    b_first, b_mid, b_last = weight_b[:, :1], weight_b[:, 1:2], weight_b[:, 2:]

    # Collapse along the last weight column: (0,1), (2,3), (4,5), (6,7).
    pairs = [
        bundles[2 * k] * a_last + bundles[2 * k + 1] * b_last
        for k in range(4)
    ]

    # Collapse along the middle weight column, then the first one.
    low = pairs[0] * a_mid + pairs[1] * b_mid
    high = pairs[2] * a_mid + pairs[3] * b_mid
    return low * a_first + high * b_first

def query_grids(self, idx, use_trilinear=False):
"""
Query the grid fields at the given indices.
Expand All @@ -548,31 +539,25 @@ def query_grids(self, idx, use_trilinear=False):
idx_mask = self.out_of_grid(aligned_idx)
x_idx, y_idx, z_idx = self.fix_out_of_grid(aligned_idx)

query_results = self.grid_fields[x_idx, y_idx, z_idx]
query_results = query_results * idx_mask.unsqueeze(-1) # zero the samples that are out of the grid

if use_trilinear:
weight_b = torch.abs(idx - aligned_idx)
weight_a = 1.0 - weight_b
query_sh, query_density = query_results[..., :-1], query_results[..., -1]
samples_density = self.trilinear_interpolation(query_density, weight_a, weight_b)
samples_sh = self.trilinear_interpolation(query_sh, weight_a, weight_b)
samples_result = torch.cat((samples_sh, samples_density), dim=3)
return samples_result
samples_sh = interpolate(self.sh_fields.unsqueeze(0), idx.unsqueeze(0)).squeeze(0)
samples_density = interpolate(self.density_fields.unsqueeze(0), idx.unsqueeze(0)).squeeze(0)
return samples_sh, samples_density

return query_results
query_sh = self.sh_fields[x_idx, y_idx, z_idx] * idx_mask.unsqueeze(-1) # zero the samples that are out of the grid
query_density = self.density_fields[x_idx, y_idx, z_idx] * idx_mask.unsqueeze(-1) # zero the samples that are out of the grid
return query_sh, query_density


def forward(self, pts, dirs):
normalized_idx = self.normalize_samples(pts)
samples_result = self.query_grids(normalized_idx)
samples_sh, samples_density = samples_reuslt[..., :-1], samples_reuslt[..., -1]
samples_rgb = torch.empty((pts.shape(0), pts.shape(1), 3), device=samples_sh.device)
sh_dim = self.net.sh_dim
samples_sh, samples_density = self.query_grids(normalized_idx, use_trilinear=True)
samples_rgb = torch.empty((pts.shape[0], 3), device=samples_sh.device)
sh_dim = self.sh_dim
for i in range(3):
sh_coeffs = samples_sh[:, :, sh_dim*i:sh_dim*(i+1)]
samples_rgb[:, :, i] = eval_sh(self.sh_degree, sh_coeffs, viewdirs)
return samples_density, samples_rgb
sh_coeffs = samples_sh[..., sh_dim*i:sh_dim*(i+1)]
samples_rgb[..., i] = eval_sh(self.sh_degree, sh_coeffs, dirs)
return samples_density.squeeze(-1), samples_rgb


MODEL_DICT = {
Expand Down