Skip to content

Commit 3fe460f

Browse files
Fix backward of SH in CUDA.
1 parent 900c592 commit 3fe460f

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

diff-gaussian-rasterization/cuda_rasterizer/backward.cu

+34
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,23 @@ __device__ void computeColorFromSH_4D(int idx, int deg, int deg_t, int max_coeff
302302
float t1 = cos(2 * MY_PI * dir_t / time_duration);
303303
float dt1_dt = sin(2 * MY_PI * dir_t / time_duration) * 2 * MY_PI / time_duration;
304304

305+
dL_dsh[16] = t1 * l0m0 * dL_dRGB;
306+
dL_dsh[17] = t1 * l1m1 * dL_dRGB;
307+
dL_dsh[18] = t1 * l1m0 * dL_dRGB;
308+
dL_dsh[19] = t1 * l1p1 * dL_dRGB;
309+
dL_dsh[20] = t1 * l2m2 * dL_dRGB;
310+
dL_dsh[21] = t1 * l2m1 * dL_dRGB;
311+
dL_dsh[22] = t1 * l2m0 * dL_dRGB;
312+
dL_dsh[23] = t1 * l2p1 * dL_dRGB;
313+
dL_dsh[24] = t1 * l2p2 * dL_dRGB;
314+
dL_dsh[25] = t1 * l3m3 * dL_dRGB;
315+
dL_dsh[26] = t1 * l3m2 * dL_dRGB;
316+
dL_dsh[27] = t1 * l3m1 * dL_dRGB;
317+
dL_dsh[28] = t1 * l3m0 * dL_dRGB;
318+
dL_dsh[29] = t1 * l3p1 * dL_dRGB;
319+
dL_dsh[30] = t1 * l3p2 * dL_dRGB;
320+
dL_dsh[31] = t1 * l3p3 * dL_dRGB;
321+
305322
dRGBdt = dt1_dt * (
306323
l0m0 * sh[16] +
307324
l1m1 * sh[17] +
@@ -366,6 +383,23 @@ __device__ void computeColorFromSH_4D(int idx, int deg, int deg_t, int max_coeff
366383
float t2 = cos(2 * MY_PI * dir_t * 2 / time_duration);
367384
float dt2_dt = sin(2 * MY_PI * dir_t * 2 / time_duration) * 2 * MY_PI * 2 / time_duration;
368385

386+
dL_dsh[32] = t2 * l0m0 * dL_dRGB;
387+
dL_dsh[33] = t2 * l1m1 * dL_dRGB;
388+
dL_dsh[34] = t2 * l1m0 * dL_dRGB;
389+
dL_dsh[35] = t2 * l1p1 * dL_dRGB;
390+
dL_dsh[36] = t2 * l2m2 * dL_dRGB;
391+
dL_dsh[37] = t2 * l2m1 * dL_dRGB;
392+
dL_dsh[38] = t2 * l2m0 * dL_dRGB;
393+
dL_dsh[39] = t2 * l2p1 * dL_dRGB;
394+
dL_dsh[40] = t2 * l2p2 * dL_dRGB;
395+
dL_dsh[41] = t2 * l3m3 * dL_dRGB;
396+
dL_dsh[42] = t2 * l3m2 * dL_dRGB;
397+
dL_dsh[43] = t2 * l3m1 * dL_dRGB;
398+
dL_dsh[44] = t2 * l3m0 * dL_dRGB;
399+
dL_dsh[45] = t2 * l3p1 * dL_dRGB;
400+
dL_dsh[46] = t2 * l3p2 * dL_dRGB;
401+
dL_dsh[47] = t2 * l3p3 * dL_dRGB;
402+
369403
dRGBdt = dt2_dt * (
370404
l0m0 * sh[32] +
371405
l1m1 * sh[33] +

0 commit comments

Comments
 (0)