Skip to content

Commit

Permalink
Add clamping after SH evaluation as in original 3DGS implementation (#76)
Browse files Browse the repository at this point in the history

* Add clamping for splat rgb as in original 3dgs implementation

* Update reference gradients, manage reference env with uv

* Spelling

* Remove now unnecessary clamps

* Slightly lower threshold (windows passes)

* Revert precision

---------

Co-authored-by: Arthur Brussee <[email protected]>
  • Loading branch information
fhahlbohm and ArthurBrussee authored Jan 6, 2025
1 parent c4351bd commit c2c05b6
Show file tree
Hide file tree
Showing 16 changed files with 1,861 additions and 188 deletions.
2 changes: 0 additions & 2 deletions crates/brush-render/src/shaders/project_visible.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
let viewdir = normalize(mean - uniforms.camera_position.xyz);

var color = sh_coeffs_to_color(sh_degree, viewdir, sh) + vec3f(0.5);
// TODO: This would be good but need to update backwards gradient as well.
// color = max(color, vec3f(0.0));

projected[compact_gid] = helpers::create_projected_splat(
mean2d,
Expand Down
3 changes: 2 additions & 1 deletion crates/brush-render/src/shaders/rasterize.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ fn main(
}

let fac = alpha * T;
pix_out += vec3f(color.r, color.g, color.b) * fac;
let clamped_rgb = max(color.rgb, vec3f(0.0));
pix_out += clamped_rgb * fac;
T = next_T;

let isect_id = batch_start + t;
Expand Down
8 changes: 5 additions & 3 deletions crates/brush-render/src/shaders/rasterize_backwards.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,12 @@ fn main(
let fac = alpha * T;

// contribution from this pixel
var v_alpha = dot(color.rgb * T - buffer * ra, v_out.rgb);
let clamped_rgb = max(color.rgb, vec3f(0.0));
var v_alpha = dot(clamped_rgb * T - buffer * ra, v_out.rgb);
v_alpha += T_final * ra * v_out.a;

// update the running sum
buffer += color.xyz * fac;
buffer += clamped_rgb * fac;

let v_sigma = -color.a * vis * v_alpha;

Expand All @@ -269,7 +270,8 @@ fn main(
v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y);

v_colors = vec4f(fac * v_out.rgb, vis * v_alpha);
v_colors = vec4f(select(fac * v_out.rgb, vec3f(0.0), color.rgb < vec3f(0.0)),
vis * v_alpha);
}
}

Expand Down
5 changes: 3 additions & 2 deletions crates/brush-render/src/tests/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async fn test_reference() -> Result<()> {
let fov_y = focal_to_fov(focal, h as u32);

let cam = Camera::new(
glam::vec3(0.123, -0.123, -8.0),
glam::vec3(0.0, 0.0, -8.0),
glam::Quat::IDENTITY,
fov_x,
fov_y,
Expand Down Expand Up @@ -111,9 +111,10 @@ async fn test_reference() -> Result<()> {
)?;
}

wrapped_aux.clone().debug_assert_valid();

// Check if images match.
assert!(out.clone().all_close(img_ref, Some(1e-5), Some(1e-6)));
wrapped_aux.clone().debug_assert_valid();

let num_visible = wrapped_aux.num_visible.into_scalar_async().await as usize;
let projected_splats =
Expand Down
1 change: 1 addition & 0 deletions crates/brush-render/test_cases/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
196 changes: 24 additions & 172 deletions crates/brush-render/test_cases/NerfStudioRefGen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 16,
"metadata": {
"id": "2vhrNCZ4rrap"
},
Expand Down Expand Up @@ -64,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -76,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -117,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand Down Expand Up @@ -155,6 +155,7 @@
" \n",
" colors = spherical_harmonics(3, dirs, coeffs[None], masks=None) # [C, N, 3]\n",
" colors = colors + 0.5\n",
" colors = colors.clamp(min=0.0)\n",
"\n",
" render_colors, render_alphas, info = rasterization(\n",
" means=means,\n",
Expand Down Expand Up @@ -214,8 +215,15 @@
" \"out_img\": out_img,\n",
" \"v_out_img\": out_img.grad,\n",
" }\n",
" save_file(tensors, f\"./{name}.safetensors\")\n",
"\n",
" save_file(tensors, f\"./{name}.safetensors\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Super simple case: a few splats visible in a tiny image.\n",
"def test_case():\n",
" torch.manual_seed(14)\n",
Expand Down Expand Up @@ -249,7 +257,7 @@
"outputs": [],
"source": [
"\n",
"# Simple case: a few splats visible in a tiny image.\n",
"# Basic case: a bunch of splats visible.\n",
"def test_case():\n",
" torch.manual_seed(3)\n",
" num_points = 16\n",
Expand Down Expand Up @@ -282,7 +290,7 @@
"outputs": [],
"source": [
"\n",
"# Super simple case: a few splats visible in a tiny image.\n",
"# Bigger test case: Lots of splats saturaing the image.\n",
"def test_case():\n",
" torch.manual_seed(4)\n",
" num_points = 76873\n",
Expand Down Expand Up @@ -310,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -377,22 +385,9 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'project_gaussians' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[19], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mtest_render\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[1;32mIn[18], line 29\u001b[0m, in \u001b[0;36mtest_render\u001b[1;34m()\u001b[0m\n\u001b[0;32m 27\u001b[0m w, h \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m512\u001b[39m, \u001b[38;5;241m512\u001b[39m)\n\u001b[0;32m 28\u001b[0m cam \u001b[38;5;241m=\u001b[39m basic_camera(w, h)\n\u001b[1;32m---> 29\u001b[0m xys, depths, radii, conics, _comp, num_tiles_hit, _cov3d \u001b[38;5;241m=\u001b[39m \u001b[43mproject_gaussians\u001b[49m(\n\u001b[0;32m 30\u001b[0m means,\n\u001b[0;32m 31\u001b[0m log_scales\u001b[38;5;241m.\u001b[39mexp(),\n\u001b[0;32m 32\u001b[0m \u001b[38;5;241m1\u001b[39m,\n\u001b[0;32m 33\u001b[0m quats,\n\u001b[0;32m 34\u001b[0m cam\u001b[38;5;241m.\u001b[39mviewmat,\n\u001b[0;32m 35\u001b[0m cam\u001b[38;5;241m.\u001b[39mfocal,\n\u001b[0;32m 36\u001b[0m cam\u001b[38;5;241m.\u001b[39mfocal,\n\u001b[0;32m 37\u001b[0m cam\u001b[38;5;241m.\u001b[39mw \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m,\n\u001b[0;32m 38\u001b[0m cam\u001b[38;5;241m.\u001b[39mh \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m,\n\u001b[0;32m 39\u001b[0m cam\u001b[38;5;241m.\u001b[39mh,\n\u001b[0;32m 40\u001b[0m cam\u001b[38;5;241m.\u001b[39mw,\n\u001b[0;32m 41\u001b[0m DEFAULT_TILE_SIZE,\n\u001b[0;32m 42\u001b[0m \u001b[38;5;241m0.01\u001b[39m\n\u001b[0;32m 43\u001b[0m )\n\u001b[0;32m 44\u001b[0m viewdirs \u001b[38;5;241m=\u001b[39m means \u001b[38;5;241m-\u001b[39m cam\u001b[38;5;241m.\u001b[39mviewmat[:\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m3\u001b[39m] \u001b[38;5;66;03m# (N, 3)\u001b[39;00m\n\u001b[0;32m 45\u001b[0m colors \u001b[38;5;241m=\u001b[39m spherical_harmonics(\u001b[38;5;241m0\u001b[39m, viewdirs, coeffs) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m0.5\u001b[39m\n",
"\u001b[1;31mNameError\u001b[0m: name 'project_gaussians' is not defined"
]
}
],
"outputs": [],
"source": [
"test_render()"
]
Expand Down Expand Up @@ -470,80 +465,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fwd base\n",
"bench_times: 0.0054480000000003415\n",
"bench_times: 0.006264699999999124\n",
"bench_times: 0.007424999999999571\n",
"bench_times: 0.008630799999999716\n",
"bench_times: 0.010573450000000761\n",
"bench_times: 0.012653049999999055\n",
"bench_times: 0.014533000000001017\n",
"bench_times: 0.016140749999999926\n",
"bench_times: 0.018148449999999983\n",
"bench_times: 0.020353200000000626\n",
"fwd dense\n",
"bench_times: 0.005170599999999581\n",
"bench_times: 0.0065869999999996764\n",
"bench_times: 0.007834599999998915\n",
"bench_times: 0.009565600000000174\n",
"bench_times: 0.012797100000000228\n",
"bench_times: 0.0159091999999994\n",
"bench_times: 0.018979899999999716\n",
"bench_times: 0.021981800000000717\n",
"bench_times: 0.02528674999999936\n",
"bench_times: 0.028196200000000005\n",
"fwd hd\n",
"bench_times: 0.0044535000000021085\n",
"bench_times: 0.006078350000001009\n",
"bench_times: 0.007822550000000206\n",
"bench_times: 0.01049919999999993\n",
"bench_times: 0.013620799999999988\n",
"bench_times: 0.01567149999999984\n",
"bench_times: 0.018067799999998968\n",
"bench_times: 0.019659700000000058\n",
"bench_times: 0.022230549999999738\n",
"bench_times: 0.02517940000000074\n",
"bwd base\n",
"bench_times: 0.008292050000001439\n",
"bench_times: 0.012299000000000504\n",
"bench_times: 0.018046149999999983\n",
"bench_times: 0.02415445000000105\n",
"bench_times: 0.03085374999999857\n",
"bench_times: 0.03746719999999826\n",
"bench_times: 0.04361719999999991\n",
"bench_times: 0.05000055000000003\n",
"bench_times: 0.05620219999999776\n",
"bench_times: 0.061733999999997735\n",
"bwd dense\n",
"bench_times: 0.015267299999997874\n",
"bench_times: 0.021559500000002174\n",
"bench_times: 0.022710149999998208\n",
"bench_times: 0.025787900000000974\n",
"bench_times: 0.030729000000004447\n",
"bench_times: 0.0361668499999972\n",
"bench_times: 0.042472500000002356\n",
"bench_times: 0.04797824999999989\n",
"bench_times: 0.053266199999999486\n",
"bench_times: 0.058767050000000154\n",
"bwd hd\n",
"bench_times: 0.010913299999998571\n",
"bench_times: 0.017144649999998762\n",
"bench_times: 0.025005099999997782\n",
"bench_times: 0.033651499999997725\n",
"bench_times: 0.04151339999999948\n",
"bench_times: 0.05062085000000138\n",
"bench_times: 0.05931074999999453\n",
"bench_times: 0.06616230000000911\n",
"bench_times: 0.07412374999999827\n",
"bench_times: 0.08495630000000176\n"
]
}
],
"outputs": [],
"source": [
"gc.collect()\n",
"\n",
Expand All @@ -563,77 +485,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'fwd_base': [0.0054480000000003415,\n",
" 0.006264699999999124,\n",
" 0.007424999999999571,\n",
" 0.008630799999999716,\n",
" 0.010573450000000761,\n",
" 0.012653049999999055,\n",
" 0.014533000000001017,\n",
" 0.016140749999999926,\n",
" 0.018148449999999983,\n",
" 0.020353200000000626],\n",
" 'fwd_dense': [0.005170599999999581,\n",
" 0.0065869999999996764,\n",
" 0.007834599999998915,\n",
" 0.009565600000000174,\n",
" 0.012797100000000228,\n",
" 0.0159091999999994,\n",
" 0.018979899999999716,\n",
" 0.021981800000000717,\n",
" 0.02528674999999936,\n",
" 0.028196200000000005],\n",
" 'fwd_hd': [0.0044535000000021085,\n",
" 0.006078350000001009,\n",
" 0.007822550000000206,\n",
" 0.01049919999999993,\n",
" 0.013620799999999988,\n",
" 0.01567149999999984,\n",
" 0.018067799999998968,\n",
" 0.019659700000000058,\n",
" 0.022230549999999738,\n",
" 0.02517940000000074],\n",
" 'bwd_base': [0.008292050000001439,\n",
" 0.012299000000000504,\n",
" 0.018046149999999983,\n",
" 0.02415445000000105,\n",
" 0.03085374999999857,\n",
" 0.03746719999999826,\n",
" 0.04361719999999991,\n",
" 0.05000055000000003,\n",
" 0.05620219999999776,\n",
" 0.061733999999997735],\n",
" 'bwd_dense': [0.015267299999997874,\n",
" 0.021559500000002174,\n",
" 0.022710149999998208,\n",
" 0.025787900000000974,\n",
" 0.030729000000004447,\n",
" 0.0361668499999972,\n",
" 0.042472500000002356,\n",
" 0.04797824999999989,\n",
" 0.053266199999999486,\n",
" 0.058767050000000154],\n",
" 'bwd_hd': [0.010913299999998571,\n",
" 0.017144649999998762,\n",
" 0.025005099999997782,\n",
" 0.033651499999997725,\n",
" 0.04151339999999948,\n",
" 0.05062085000000138,\n",
" 0.05931074999999453,\n",
" 0.06616230000000911,\n",
" 0.07412374999999827,\n",
" 0.08495630000000176]}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"bench_times"
]
Expand Down Expand Up @@ -702,7 +554,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "nerfstudio",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -716,7 +568,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
6 changes: 6 additions & 0 deletions crates/brush-render/test_cases/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Python project to generate reference gradients with gSplat.


To use this, please install `uv`. Then:
- To generate new samples, use `uv run generate_reference.py`.
- To interact with the notebook, use `uv run --with jupyter jupyter-lab` to start a kernel, and open the py file as a notebook.
Binary file modified crates/brush-render/test_cases/basic_case.safetensors
Binary file not shown.
Binary file modified crates/brush-render/test_cases/mix_case.safetensors
Binary file not shown.
Loading

0 comments on commit c2c05b6

Please sign in to comment.