Experiments (WIP)

lightfield botanist edited this page Jun 25, 2022 · 85 revisions

Experiments

Visualize NeRF in fewer than 100 lines of code with PyTorch

https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

  • Positional encoding
  • The radiance field function approximator (in this case, an MLP)
  • Differentiable volume renderer
  • Stratified sampling
  • Hierarchical volume sampling 4-10 layers, Positional encoding (frequencies)
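
The positional encoding named in the list above can be sketched as follows. This is a minimal NumPy sketch, not code from the linked article: the function name `positional_encoding` and the parameters `num_freqs` and `include_input` are assumptions made for illustration. It uses frequencies of the form 2^k·π as in the NeRF paper; some implementations drop the factor π.

```python
import numpy as np

def positional_encoding(p, num_freqs, include_input=True):
    """Map each coordinate of p to sines and cosines of increasing frequency.

    p: array of shape (..., D), e.g. D=3 for positions or directions.
    Returns an array of shape (..., 2 * D * num_freqs [+ D]).
    """
    p = np.asarray(p, dtype=np.float64)
    features = [p] if include_input else []
    for k in range(num_freqs):
        freq = (2.0 ** k) * np.pi  # assumed frequency schedule: 2^k * pi
        features.append(np.sin(freq * p))
        features.append(np.cos(freq * p))
    return np.concatenate(features, axis=-1)

points = np.zeros((4, 3))
encoded = positional_encoding(points, num_freqs=10)
print(encoded.shape)  # (4, 63): 3 raw coordinates + 2 * 3 * 10 encoded values
```

With `include_input=True` the output width is `2 * 3 * num_freqs + 3`, which is the same input size the Keras model further down expects (`2 * 3 * POS_ENCODE_DIMS + 3`).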

NeRF, inspired by this representation, attempts to approximate a function that maps a 5D coordinate (a 3D position x plus a 2D viewing direction d) to a 4D output consisting of a color c = (R, G, B) and a density σ, which you can think of as the likelihood that the light ray is terminated at that 5D coordinate (e.g. by occlusion). The standard NeRF is thus a function of the form F : (x, d) -> (c, σ).

The original NeRF paper parameterizes this function with a multilayer perceptron trained directly on a set of images with known poses (which can be obtained via any structure-from-motion application such as COLMAP, Agisoft Metashape, Reality Capture, or Meshroom).

Hierarchical sampling: coarse and fine

If you can't explain it simply, you don't understand it well enough.

NeRF proposes a hierarchical structure. The overall network architecture is composed of two networks: the coarse network and the fine network.
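
One way to picture how the coarse and fine networks work together: the weights produced by the coarse pass define a piecewise-constant PDF along the ray, and the fine pass draws extra samples from it by inverse transform sampling. The sketch below assumes that scheme; `sample_fine`, `bin_edges` and the example weights are hypothetical names and values, not taken from any of the implementations on this page.

```python
import numpy as np

def sample_fine(bin_edges, weights, num_fine, rng):
    """Draw num_fine depths from the piecewise-constant PDF defined by weights.

    bin_edges: (N + 1,) sorted depths bounding the N coarse bins.
    weights:   (N,) non-negative coarse weights, one per bin.
    """
    pdf = (weights + 1e-5) / np.sum(weights + 1e-5)   # avoid an all-zero PDF
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])      # (N + 1,), rises from 0 to 1
    u = rng.uniform(0.0, 1.0, size=num_fine)
    # Find the bin each u falls into, then interpolate linearly inside it.
    idx = np.clip(np.searchsorted(cdf, u, side="right") - 1, 0, len(weights) - 1)
    frac = (u - cdf[idx]) / (cdf[idx + 1] - cdf[idx])
    return np.sort(bin_edges[idx] + frac * (bin_edges[idx + 1] - bin_edges[idx]))

rng = np.random.default_rng(0)
edges = np.linspace(2.0, 6.0, 9)   # 8 coarse bins between near=2 and far=6
coarse_w = np.array([0, 0, 0, 0.1, 0.8, 0.1, 0, 0], dtype=float)
fine_t = sample_fine(edges, coarse_w, num_fine=64, rng=rng)
```

In this example almost all fine samples land between depths 3.5 and 5.0, where the coarse weights are concentrated, so the fine network spends its evaluations near the likely surface.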

The model is a multi-layer perceptron (MLP), with ReLU as its non-linearity.

The MLP first processes the input 3D coordinate x with 8 fully-connected layers (using ReLU activations and 256 channels per layer), and outputs sigma and a 256-dimensional feature vector. This feature vector is then concatenated with the camera ray's viewing direction and passed to one additional fully-connected layer (using a ReLU activation and 128 channels) that outputs the view-dependent RGB color.

class NerfModel(nn.Module): # pytorch
    def __init__(self, embedding_dim_pos=20, embedding_dim_direction=8, hidden_dim=128):
        super(NerfModel, self).__init__()
        self.block1 = nn.Sequential(nn.Linear(embedding_dim_pos * 3, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.block2 = nn.Sequential(nn.Linear(embedding_dim_pos * 3 + hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        self.block3 = nn.Sequential(nn.Linear(embedding_dim_direction * 3 + hidden_dim, hidden_dim // 2), nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

# `device` is assumed to be defined elsewhere, e.g. torch.device("cuda").
model = NerfModel(hidden_dim=256).to(device)

def get_nerf_model(num_layers, num_pos):  # Keras (tf.keras)
    """Generates the NeRF neural network.
    Args:
        num_layers: The number of MLP layers.
        num_pos: The number of dimensions of positional encoding.
    Returns:         The `tf.keras` model.
    """
    inputs = keras.Input(shape=(num_pos, 2 * 3 * POS_ENCODE_DIMS + 3))
    x = inputs
    for i in range(num_layers):
        x = layers.Dense(units=64, activation="relu")(x)
        if i % 4 == 0 and i > 0:
            # Inject residual connection.
            x = layers.concatenate([x, inputs], axis=-1)
    outputs = layers.Dense(units=4)(x)
    return keras.Model(inputs=inputs, outputs=outputs)

class DirectTemporalNeRF(nn.Module):  # TiNeuVox
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, input_ch_time=1, output_ch=4, skips=[4],
                 use_viewdirs=False, memory=[], embed_fn=None, zero_canonical=True):
        super(DirectTemporalNeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.input_ch_time = input_ch_time
        self.skips = skips
        self.use_viewdirs = use_viewdirs
        self.memory = memory
        self.embed_fn = embed_fn
        self.zero_canonical = zero_canonical

        self._occ = NeRFOriginal(D=D, W=W, input_ch=input_ch, input_ch_views=input_ch_views,
                                 input_ch_time=input_ch_time, output_ch=output_ch, skips=skips,
                                 use_viewdirs=use_viewdirs, memory=memory, embed_fn=embed_fn, output_color_ch=3)
        self._time, self._time_out = self.create_time_net()

    def create_time_net(self):
        layers = [nn.Linear(self.input_ch + self.input_ch_time, self.W)]
        for i in range(self.D - 1):
            if i in self.memory:
                raise NotImplementedError
            else:
                layer = nn.Linear

            in_channels = self.W
            if i in self.skips:
                in_channels += self.input_ch

            layers += [layer(in_channels, self.W)]
        return nn.ModuleList(layers), nn.Linear(self.W, 3)

    def forward(self, x, ts):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        t = ts[0]

        assert len(torch.unique(t[:, :1])) == 1, "Only accepts all points from same time"
        cur_time = t[0, 0]
        if cur_time == 0. and self.zero_canonical:
            dx = torch.zeros_like(input_pts[:, :3])
        else:
            dx = self.query_time(input_pts, t, self._time, self._time_out)
            input_pts_orig = input_pts[:, :3]
            input_pts = self.embed_fn(input_pts_orig + dx)
        out, _ = self._occ(torch.cat([input_pts, input_views], dim=-1), t)
        return out, dx

class NeRFOriginal(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, input_ch_time=1, output_ch=4, skips=[4],
                 use_viewdirs=False, memory=[], embed_fn=None, output_color_ch=3, zero_canonical=True):
        super(NeRFOriginal, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        # self.pts_linears = nn.ModuleList(
        #     [nn.Linear(input_ch, W)] +
        #     [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])

        layers = [nn.Linear(input_ch, W)]
        for i in range(D - 1):
            if i in memory:
                raise NotImplementedError
            else:
                layer = nn.Linear
            in_channels = W
            if i in self.skips:
                in_channels += input_ch
            layers += [layer(in_channels, W)]
        self.pts_linears = nn.ModuleList(layers)

        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, output_color_ch)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x, ts):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)
        return outputs, torch.zeros_like(input_pts[:, :3])

https://github.com/3a1b2c3/Minimal-NeRF/blob/main/nerf_helpers.py#L162

def view_reconstruction(model, all_o_rays, all_d_rays, N=4096):
    """Queries the model at every ray direction to generate an image from a view.
    
    Args:
        model: a nerf_model.ImageNeRFModel object 
        all_o_rays: [H x W x 3] vector of 3D origins. (should all be identical)
        all_d_rays: [H x W x 3] vector of directions.
        N: batch size to pass through model.
    Returns:
        an [im_h x im_w x 3] numpy array representing an image.
    """
    H, W, C = all_o_rays.shape
    all_o_rays = all_o_rays.view((H*W, C))
    all_d_rays = all_d_rays.view((H*W, C))
    im = []
    for i in range(0, H*W, N): 
        recon_preds = model.forward(all_o_rays[i:min(H*W,i+N),:], all_d_rays[i:min(H*W,i+N),:])
        im.append(recon_preds['fine_rgb_rays'].cpu().clone().detach().numpy())
    im = np.concatenate(im, axis=0).reshape((H, W, C))
    im *= 255
    im = np.clip(im, 0, 255)
    return im.astype(np.uint8)

def photo_nerf_to_image(model, im_h, im_w): 
    """Queries the model at every idx to generate an image 
    
    Args:
        model: a nerf_model.ImageNeRFModel object 
        im_h: the height of the image 
        im_w: the width of the image 
    Returns:
        an [im_h x im_w x 3] tensor representing an image.
    """
    if type(im_h) != int:
        im_h = int(im_h[0])
        im_w = int(im_w[0])
    idxs = [(i,j) for i,j in itertools.product(np.arange(0,im_h), np.arange(0,im_w))]
    idxs = torch.FloatTensor(idxs).to(model.device)
    idxs[:,0] /= (im_h-1)
    idxs[:,1] /= (im_w-1)
    N, _ = idxs.shape
    recon = []
    step = 4096
    for i in range(0, N, step):
        # break up whole tensor into sizeable chunks
        batch = idxs[i:i+step,:]
        rgb = model(batch)
        recon.append(rgb)
    recon = torch.cat(recon, axis=0).reshape((im_h, im_w, 3))
    return recon

Shoot a ray through each pixel and sample some points on the ray. A ray is parameterized by the equation r(t) = o + td where t is the parameter, o is the origin and d is the unit directional vector. The ray has a viewing angle (theta, phi). These sampled points act as the input to the NeRF model. The model is then asked to predict the RGB color and the volume density at that point.
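
The ray setup described above can be sketched in NumPy as follows. The helper names `get_rays` and `sample_along_rays`, the parameters `focal` and `cam_to_world`, and the camera convention (pinhole camera looking down -z with +y up) are assumptions for illustration; the implementations on this page may use different conventions.

```python
import numpy as np

def get_rays(height, width, focal, cam_to_world):
    """Return per-pixel ray origins and unit directions, each (height, width, 3).

    Assumes a pinhole camera looking down -z with +y up and a 4x4
    camera-to-world matrix.
    """
    i, j = np.meshgrid(np.arange(width, dtype=np.float64),
                       np.arange(height, dtype=np.float64), indexing="xy")
    dirs_cam = np.stack([(i - 0.5 * width) / focal,
                         -(j - 0.5 * height) / focal,
                         -np.ones_like(i)], axis=-1)
    dirs_world = dirs_cam @ cam_to_world[:3, :3].T   # rotate into world space
    dirs_world /= np.linalg.norm(dirs_world, axis=-1, keepdims=True)
    origins = np.broadcast_to(cam_to_world[:3, 3], dirs_world.shape)
    return origins, dirs_world

def sample_along_rays(origins, directions, near, far, num_samples):
    """Evaluate r(t) = o + t * d at evenly spaced t in [near, far]."""
    t = np.linspace(near, far, num_samples)
    points = origins[..., None, :] + t[:, None] * directions[..., None, :]
    return points, t

rays_o, rays_d = get_rays(4, 6, focal=5.0, cam_to_world=np.eye(4))
pts, t_vals = sample_along_rays(rays_o, rays_d, near=2.0, far=6.0, num_samples=8)
print(pts.shape)  # (4, 6, 8, 3): one 3D point per pixel and per sample
```

Each of the `pts`, paired with its ray direction, is what the model is queried with to obtain a color and a density.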

Nerf_pl pytorch implementation (https://github.com/kwea123/nerf_pl)

The code is largely based on the NeRF implementation (see the master or dev branch); the main differences are the model structure and the rendering process, both of which can be found in the two files under models/.

Local Light Field Fusion (LLFF) [28]: LLFF is designed for producing photorealistic novel views for well-sampled forward-facing scenes. It uses a trained 3D convolutional network to directly predict a discretized frustum-sampled RGBα grid (multiplane image or MPI [52]) for each input view, then renders novel views by alpha compositing and blending nearby MPIs into the novel viewport.
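
The alpha compositing step can be illustrated with the standard "over" operator applied to a stack of RGBα planes. This is only a sketch of that one step: it assumes the planes are already warped into the target view and ordered back to front, and it leaves out the blending of neighboring MPIs. The name `composite_mpi` is hypothetical.

```python
import numpy as np

def composite_mpi(rgba_planes):
    """Alpha-composite MPI planes ordered from back (index 0) to front.

    rgba_planes: (D, H, W, 4) with colors and alpha in [0, 1].
    Returns an (H, W, 3) image.
    """
    image = np.zeros(rgba_planes.shape[1:3] + (3,))
    for plane in rgba_planes:                    # iterate back to front
        rgb, a = plane[..., :3], plane[..., 3:4]
        image = rgb * a + image * (1.0 - a)      # standard "over" operator
    return image

# A 1x1 example: an opaque red plane behind a half-transparent blue plane.
planes = np.array([
    [[[1.0, 0.0, 0.0, 1.0]]],
    [[[0.0, 0.0, 1.0, 0.5]]],
])
pixel = composite_mpi(planes)  # equal parts red and blue
```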

The regularizer coalesces distributed density (along each ray) to a minimal spaced discrete sample where possible. Basically it squeezes soft data to a solid surface. This is what's responsible for eliminating all the floaters.
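
The passage does not name the regularizer. Assuming it refers to the distortion loss introduced with Mip-NeRF 360, a minimal NumPy sketch for a single ray could look like this; the function name and the example values are illustrative only.

```python
import numpy as np

def distortion_loss(s, w):
    """Distortion loss for one ray (assumed form, see the note above).

    s: (N + 1,) normalized interval endpoints along the ray.
    w: (N,) rendering weights, one per interval.
    """
    mid = 0.5 * (s[:-1] + s[1:])
    # Pairwise term: pulls weighted intervals toward each other.
    pairwise = np.sum(w[:, None] * w[None, :] * np.abs(mid[:, None] - mid[None, :]))
    # Self term: shrinks the width of each weighted interval.
    self_term = np.sum(w ** 2 * (s[1:] - s[:-1])) / 3.0
    return pairwise + self_term

s = np.linspace(0.0, 1.0, 5)
spread = distortion_loss(s, np.array([0.25, 0.25, 0.25, 0.25]))
compact = distortion_loss(s, np.array([0.0, 1.0, 0.0, 0.0]))
```

Weights spread evenly along the ray give a larger loss than the same total weight concentrated in one interval, which is the "squeezing" behavior described above.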

More Hashnerf (pytorch)

enum class GridType {
	Hash,
	Dense,
	Tiled,
};

https://github.com/3a1b2c3/torch-ngp

Rendering in different implementations

def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
    """Generates the RGB image and depth map from model prediction.

    Args:
        model: The MLP model that is trained to predict the rgb and
            volume density of the volumetric scene.
        rays_flat: The flattened rays that serve as the input to
            the NeRF model.
        t_vals: The sample points for the rays.
        rand: Choice to randomise the sampling strategy.
        train: Whether the model is in the training or testing phase.

    Returns:
        Tuple of rgb image and depth map.
    """
    # Get the predictions from the nerf model and reshape it.
    if train:
        predictions = model(rays_flat)
    else:
        predictions = model.predict(rays_flat)
    predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))

    # Slice the predictions into rgb and sigma.
    rgb = tf.sigmoid(predictions[..., :-1])
    sigma_a = tf.nn.relu(predictions[..., -1])

    # Get the distance of adjacent intervals.
    delta = t_vals[..., 1:] - t_vals[..., :-1]
    # delta shape = (num_samples)
    if rand:
        delta = tf.concat(
            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
        )
        alpha = 1.0 - tf.exp(-sigma_a * delta)
    else:
        delta = tf.concat(
            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
        )
        alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])

    # Get transmittance.
    exp_term = 1.0 - alpha
    epsilon = 1e-10
    transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
    weights = alpha * transmittance
    rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)

    if rand:
        depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
    else:
        depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
    return (rgb, depth_map)

def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    device = ray_origins.device
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    # Perturb sampling along each ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)   # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    T = compute_accumulated_transmittance(1 - alpha)   # [batch_size, nb_bins]
    return (T.unsqueeze(2) * alpha.unsqueeze(2) * colors).sum(dim=1)  # Pixel values
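
`render_rays` calls `compute_accumulated_transmittance`, which is not shown on this page. A plausible reading is that it returns the exclusive cumulative product of `1 - alpha` along the sample axis, that is, the fraction of light that reaches each sample. The NumPy sketch below illustrates that assumption (the function name here is hypothetical); a PyTorch version could be built the same way from `torch.cumprod` and `torch.cat`.

```python
import numpy as np

def accumulated_transmittance(one_minus_alpha):
    """Exclusive cumulative product along the sample axis.

    one_minus_alpha: (batch_size, nb_bins) values of 1 - alpha_i.
    Returns T with T[:, 0] = 1 and T[:, i] = prod_{j < i} (1 - alpha_j).
    """
    cumulative = np.cumprod(one_minus_alpha, axis=1)
    ones = np.ones((one_minus_alpha.shape[0], 1))
    return np.concatenate([ones, cumulative[:, :-1]], axis=1)

alpha = np.array([[0.5, 0.5, 1.0]])
T = accumulated_transmittance(1.0 - alpha)   # T is [[1.0, 0.5, 0.25]]
weights = T * alpha                          # weights are [[0.5, 0.25, 0.25]]
```

The per-sample weights `T * alpha` are what multiply the colors in the last line of `render_rays`. In this example the final sample is fully opaque, so the weights sum to 1.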

Original nerf

We synthesize views by querying 5D coordinates along camera rays and use classic volume rendering techniques to project the output colors and densities into an image. Because volume rendering is naturally differentiable, the only input required to optimize our representation is a set of images with known camera poses.

https://github.com/bmild/nerf/blob/20a91e764a28816ee2234fcadb73bd59a613a44c/run_nerf.py#L60

Hash_nerf

TiNeuVox

Web

Real time demo

https://www.youtube.com/watch?v=-eJHNvGQYmY
https://www.pexels.com/video/drone-shot-of-forest-856478/

  • https://phog.github.io/snerg/ Baking Neural Radiance Fields for Real-Time View Synthesis

    Our method precomputes and stores ("bakes") a NeRF into a Sparse Neural Radiance Grid (SNeRG) data structure. In order to render our SNeRG data structure in real time, we:

      • Use a sparse voxel grid to skip empty space along rays
      • Look up a diffuse color for each point sampled along a ray in occupied space, and composite these along the ray
      • Look up a feature vector (4-dimensional) for each point, and composite these along the ray
      • Decode the composited features into a single specular color per pixel using a tiny (2 layers, 16 channels) MLP
      • Add the diffuse and specular color components to compute the final RGB color
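
The rendering steps listed above can be sketched for a single ray as follows. This is a toy illustration, not the SNeRG implementation: the names are hypothetical, the decoder uses random weights, "2 layers, 16 channels" is read here as two hidden layers of width 16, and only the composited features are fed to the decoder, as the list describes. Empty space is masked rather than actually skipped.

```python
import numpy as np

def render_baked_ray(occupied, alpha, diffuse, features, mlp):
    """Composite one ray through a baked grid and add a decoded specular term.

    occupied: (N,) bool, False where the sparse grid marks empty space.
    alpha:    (N,) opacity per sample; diffuse: (N, 3); features: (N, 4).
    mlp:      callable mapping a (4,) feature vector to a (3,) specular color.
    """
    alpha = np.where(occupied, alpha, 0.0)         # empty samples contribute nothing
    transmittance = np.concatenate([[1.0], np.cumprod(1.0 - alpha)[:-1]])
    weights = transmittance * alpha                # (N,)
    diffuse_rgb = weights @ diffuse                # composited diffuse color
    feature_vec = weights @ features               # composited feature vector
    specular_rgb = mlp(feature_vec)                # one decoder call per pixel
    return np.clip(diffuse_rgb + specular_rgb, 0.0, 1.0)

def make_tiny_mlp(rng, hidden=16):
    """Hypothetical stand-in decoder with random weights."""
    w1 = rng.normal(size=(4, hidden))
    w2 = rng.normal(size=(hidden, hidden))
    w3 = rng.normal(size=(hidden, 3))

    def mlp(f):
        h = np.maximum(f @ w1, 0.0)
        h = np.maximum(h @ w2, 0.0)
        return 1.0 / (1.0 + np.exp(-(h @ w3)))     # sigmoid keeps colors in [0, 1]
    return mlp

rng = np.random.default_rng(0)
n = 8
occupied = np.array([False, False, True, True, False, True, False, False])
decoder = make_tiny_mlp(rng)
pixel = render_baked_ray(occupied, rng.uniform(size=n), rng.uniform(size=(n, 3)),
                         rng.uniform(size=(n, 4)), decoder)
```

The point of the design is that the MLP runs once per pixel on composited features instead of once per sample, which is what makes real-time rendering feasible.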

Photogrammetry vs Nerf meshes

https://github.com/hustvl/TiNeuVox

ReLU Fields

https://geometry.cs.ucl.ac.uk/group_website/projects/2022/relu_fields