Skip to content

Commit

Permalink
clang-format, black, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 5, 2023
1 parent c8c56fe commit e9b1abb
Show file tree
Hide file tree
Showing 20 changed files with 906 additions and 559 deletions.
5 changes: 5 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
BasedOnStyle: LLVM
AlignAfterOpenBracket: BlockIndent
BinPackArguments: false
BinPackParameters: false
IndentWidth: 4
6 changes: 6 additions & 0 deletions .clangd
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CompileFlags:
Add:
- --cuda-gpu-arch=sm_75
Remove:
- --generate-code=*
- -forward-unknown-to-host-compiler
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
*build
compile_commands.json
128 changes: 83 additions & 45 deletions csrc/backward.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#include "helpers.cuh"
#include "backward.cuh"
#include "helpers.cuh"
#include <cooperative_groups.h>

namespace cg = cooperative_groups;


__global__ void rasterize_backward_kernel(
const dim3 tile_bounds,
const dim3 img_size,
Expand All @@ -27,8 +26,8 @@ __global__ void rasterize_backward_kernel(
uint32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x;
unsigned i = blockIdx.y * blockDim.y + threadIdx.y;
unsigned j = blockIdx.x * blockDim.x + threadIdx.x;
float px = (float) j;
float py = (float) i;
float px = (float)j;
float py = (float)i;
uint32_t pix_id = i * img_size.x + j;

// which gaussians get gradients for this pixel
Expand All @@ -43,18 +42,18 @@ __global__ void rasterize_backward_kernel(
float3 S = {0.f, 0.f, 0.f};
int bin_final = final_index[pix_id];

// iterate backward to compute the jacobians wrt rgb, opacity, mean2d, and conic
// recursively compute T_{n-1} from T_n, where T_i = prod(j < i) (1 - alpha_j),
// and S_{n-1} from S_n, where S_j = sum_{i > j}(rgb_i * alpha_i * T_i)
// df/dalpha_i = rgb_i * T_i - S_{i+1| / (1 - alpha_i)
// iterate backward to compute the jacobians wrt rgb, opacity, mean2d, and
// conic recursively compute T_{n-1} from T_n, where T_i = prod(j < i) (1 -
// alpha_j), and S_{n-1} from S_n, where S_j = sum_{i > j}(rgb_i * alpha_i *
// T_i) df/dalpha_i = rgb_i * T_i - S_{i+1| / (1 - alpha_i)
for (int idx = bin_final - 1; idx >= range.x; --idx) {
uint32_t g = gaussians_ids_sorted[idx];
conic = conics[g];
center = xys[g];
delta = {center.x - px, center.y - py};
sigma = 0.5f * (
conic.x * delta.x * delta.x + conic.z * delta.y * delta.y
) - conic.y * delta.x * delta.y;
sigma =
0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) -
conic.y * delta.x * delta.y;
if (sigma <= 0.f) {
continue;
}
Expand Down Expand Up @@ -93,10 +92,15 @@ __global__ void rasterize_backward_kernel(
atomicAdd(&(v_conic[g].x), v_sigma * delta.x * delta.x);
atomicAdd(&(v_conic[g].y), v_sigma * delta.x * delta.y);
atomicAdd(&(v_conic[g].z), v_sigma * delta.y * delta.y);
atomicAdd(&(v_xy[g].x), v_sigma * 2.f * (conic.x * delta.x + conic.y * delta.y));
atomicAdd(&(v_xy[g].y), v_sigma * 2.f * (conic.y * delta.x + conic.z * delta.y));
atomicAdd(
&(v_xy[g].x),
v_sigma * 2.f * (conic.x * delta.x + conic.y * delta.y)
);
atomicAdd(
&(v_xy[g].y),
v_sigma * 2.f * (conic.y * delta.x + conic.z * delta.y)
);
}

}

void rasterize_backward_impl(
Expand All @@ -116,9 +120,9 @@ void rasterize_backward_impl(
float3 *v_conic,
float3 *v_rgb,
float *v_opacity

) {
rasterize_backward_kernel <<< tile_bounds, block >>> (
rasterize_backward_kernel<<<tile_bounds, block>>>(
tile_bounds,
img_size,
gaussians_ids_sorted,
Expand All @@ -137,7 +141,6 @@ void rasterize_backward_impl(
);
}


__global__ void project_gaussians_backward_kernel(
const int num_points,
const float3 *means3d,
Expand All @@ -160,7 +163,7 @@ __global__ void project_gaussians_backward_kernel(
float3 *v_scale,
float4 *v_quat
) {
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
if (idx >= num_points || radii[idx] <= 0) {
return;
}
Expand All @@ -183,7 +186,12 @@ __global__ void project_gaussians_backward_kernel(
);
// get v_scale and v_quat
scale_rot_to_cov3d_vjp(
scales[idx], glob_scale, quats[idx], &(v_cov3d[6 * idx]), v_scale[idx], v_quat[idx]
scales[idx],
glob_scale,
quats[idx],
&(v_cov3d[6 * idx]),
v_scale[idx],
v_quat[idx]
);
}

Expand All @@ -209,8 +217,9 @@ void project_gaussians_backward_impl(
float3 *v_scale,
float4 *v_quat
) {
project_gaussians_backward_kernel
<<< (num_points + N_THREADS - 1) / N_THREADS, N_THREADS >>> (
project_gaussians_backward_kernel<<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
means3d,
scales,
Expand Down Expand Up @@ -248,33 +257,57 @@ __device__ void project_cov3d_ewa_vjp(
// viewmat is row major, glm is column major
// upper 3x3 submatrix
glm::mat3 W = glm::mat3(
viewmat[0], viewmat[4], viewmat[8],
viewmat[1], viewmat[5], viewmat[9],
viewmat[2], viewmat[6], viewmat[10]
viewmat[0],
viewmat[4],
viewmat[8],
viewmat[1],
viewmat[5],
viewmat[9],
viewmat[2],
viewmat[6],
viewmat[10]
);
glm::vec3 p = glm::vec3(viewmat[3], viewmat[7], viewmat[11]);
glm::vec3 t = W * glm::vec3(mean3d.x, mean3d.y, mean3d.z) + p;

glm::mat3 J = glm::mat3(
fx / t.z, 0.f, 0.f,
0.f, fy / t.z, 0.f,
-fx * t.x / (t.z * t.z), -fy * t.y / (t.z * t.z), 0.f
fx / t.z,
0.f,
0.f,
0.f,
fy / t.z,
0.f,
-fx * t.x / (t.z * t.z),
-fy * t.y / (t.z * t.z),
0.f
);

glm::mat3 T = J * W;

glm::mat3 V = glm::mat3(
cov3d[0], cov3d[1], cov3d[2],
cov3d[1], cov3d[3], cov3d[4],
cov3d[2], cov3d[4], cov3d[5]
cov3d[0],
cov3d[1],
cov3d[2],
cov3d[1],
cov3d[3],
cov3d[4],
cov3d[2],
cov3d[4],
cov3d[5]
);

// df/dcov is nonzero only in upper 2x2 submatrix,
// bc we crop, so no gradients elsewhere
glm::mat3 v_cov = glm::mat3(
v_cov2d.x, 0.5f * v_cov2d.y, 0.f,
0.5f * v_cov2d.y, v_cov2d.z, 0.f,
0.f, 0.f, 0.f
v_cov2d.x,
0.5f * v_cov2d.y,
0.f,
0.5f * v_cov2d.y,
v_cov2d.z,
0.f,
0.f,
0.f,
0.f
);

// cov = T * V * Tt; G = df/dcov
Expand Down Expand Up @@ -304,15 +337,14 @@ __device__ void project_cov3d_ewa_vjp(
glm::vec3 v_t = glm::vec3(
-fx * rz2 * v_J[0][2],
-fy * rz2 * v_J[1][2],
-fx * rz2 * v_J[0][0] + 2.f * fx * t.x * rz3 * v_J[0][2]
- fy * rz2 * v_J[1][1] + 2.f * fy * t.y * rz3 * v_J[1][2]
-fx * rz2 * v_J[0][0] + 2.f * fx * t.x * rz3 * v_J[0][2] -
fy * rz2 * v_J[1][1] + 2.f * fy * t.y * rz3 * v_J[1][2]
);
v_mean3d.x += (float) glm::dot(v_t, W[0]);
v_mean3d.y += (float) glm::dot(v_t, W[1]);
v_mean3d.z += (float) glm::dot(v_t, W[2]);
v_mean3d.x += (float)glm::dot(v_t, W[0]);
v_mean3d.y += (float)glm::dot(v_t, W[1]);
v_mean3d.z += (float)glm::dot(v_t, W[2]);
}


// given cotangent v in output space (e.g. d_L/d_cov3d) in R(6)
// compute vJp for scale and rotation
__device__ void scale_rot_to_cov3d_vjp(
Expand All @@ -327,9 +359,15 @@ __device__ void scale_rot_to_cov3d_vjp(
// off-diagonal elements count grads from both ij and ji elements,
// must halve when expanding back into symmetric matrix
glm::mat3 v_V = glm::mat3(
v_cov3d[0], 0.5 * v_cov3d[1], 0.5 * v_cov3d[2],
0.5 * v_cov3d[1], v_cov3d[3], 0.5 * v_cov3d[4],
0.5 * v_cov3d[2], 0.5 * v_cov3d[4], v_cov3d[5]
v_cov3d[0],
0.5 * v_cov3d[1],
0.5 * v_cov3d[2],
0.5 * v_cov3d[1],
v_cov3d[3],
0.5 * v_cov3d[4],
0.5 * v_cov3d[2],
0.5 * v_cov3d[4],
v_cov3d[5]
);
glm::mat3 R = quat_to_rotmat(quat);
glm::mat3 S = scale_to_mat(scale, glob_scale);
Expand All @@ -339,9 +377,9 @@ __device__ void scale_rot_to_cov3d_vjp(
// df/dW = G * XT, df/dX = WT * G
glm::mat3 v_M = 2.f * v_V * M;
// glm::mat3 v_S = glm::transpose(R) * v_M;
v_scale.x = (float) glm::dot(R[0], v_M[0]);
v_scale.y = (float) glm::dot(R[1], v_M[1]);
v_scale.z = (float) glm::dot(R[2], v_M[2]);
v_scale.x = (float)glm::dot(R[0], v_M[0]);
v_scale.y = (float)glm::dot(R[1], v_M[1]);
v_scale.z = (float)glm::dot(R[2], v_M[2]);

glm::mat3 v_R = v_M * S;
v_quat = quat_to_rotmat_vjp(quat, v_R);
Expand Down
88 changes: 45 additions & 43 deletions csrc/bindings.cu
Original file line number Diff line number Diff line change
@@ -1,69 +1,71 @@
#include <math.h>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <iostream>
#include "bindings.h"
#include "forward.cuh"
#include "helpers.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <tuple>
#include <cstdio>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include "helpers.cuh"
#include "bindings.h"
#include <glm/glm.hpp>
#include <glm/gtc/type_ptr.hpp>
#include "forward.cuh"
#include <iostream>
#include <math.h>
#include <torch/extension.h>
#include <tuple>

namespace cg = cooperative_groups;


template <typename scalar_t>
__global__ void compute_cov2d_bounds_forward_kernel(
const int num_pts,
const scalar_t * __restrict__ A,
scalar_t * __restrict__ conics,
scalar_t * __restrict__ radii
){
unsigned row = cg::this_grid().thread_rank(); // same as threadIdx.x + blockIdx.x * blockDim.x;
if (row>=num_pts){return;}
const scalar_t *__restrict__ A,
scalar_t *__restrict__ conics,
scalar_t *__restrict__ radii
) {
unsigned row = cg::this_grid().thread_rank(
); // same as threadIdx.x + blockIdx.x * blockDim.x;
if (row >= num_pts) {
return;
}
int index = row * 3;

float3 conic;
float radius;
float3 cov2d{(float)A[index], (float)A[index+1], (float)A[index+2]};
compute_cov2d_bounds(cov2d,conic,radius);
float3 cov2d{(float)A[index], (float)A[index + 1], (float)A[index + 2]};
compute_cov2d_bounds(cov2d, conic, radius);

conics[index] = conic.x;
conics[index+1] = conic.y;
conics[index+2] = conic.z;
radii[row]=radius;
conics[index + 1] = conic.y;
conics[index + 2] = conic.z;
radii[row] = radius;
}


std::tuple<
torch::Tensor, // output conics
torch::Tensor // ouptut radii
>
compute_cov2d_bounds_forward_tensor(
const int num_pts,
torch::Tensor A
){
std::
tuple<
torch::Tensor, // output conics
torch::Tensor // ouptut radii
>
compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor A) {
CHECK_INPUT(A);

torch::Tensor conics = torch::zeros({num_pts, A.size(1)}, A.options().dtype(torch::kFloat32));
torch::Tensor radii = torch::zeros({num_pts, 1}, A.options().dtype(torch::kFloat32));
torch::Tensor conics =
torch::zeros({num_pts, A.size(1)}, A.options().dtype(torch::kFloat32));
torch::Tensor radii =
torch::zeros({num_pts, 1}, A.options().dtype(torch::kFloat32));

int blocks = (num_pts + N_THREADS - 1) / N_THREADS;
// instantiate kernel
AT_DISPATCH_FLOATING_TYPES(A.type(), "compute_cov2d_bounds_cu_forward",
([&] {
compute_cov2d_bounds_forward_kernel<scalar_t><<<blocks, N_THREADS>>>(
num_pts,
A.contiguous().data_ptr<scalar_t>(),
conics.contiguous().data_ptr<scalar_t>(),
radii.contiguous().data_ptr<scalar_t>()
);
})
AT_DISPATCH_FLOATING_TYPES(
A.type(), "compute_cov2d_bounds_cu_forward", ([&] {
compute_cov2d_bounds_forward_kernel<scalar_t>
<<<blocks, N_THREADS>>>(
num_pts,
A.contiguous().data_ptr<scalar_t>(),
conics.contiguous().data_ptr<scalar_t>(),
radii.contiguous().data_ptr<scalar_t>()
);
})
);
return std::make_tuple(conics, radii);
}
Loading

0 comments on commit e9b1abb

Please sign in to comment.