From b77c2486b091f23db28ec598c14cab415c4da2c6 Mon Sep 17 00:00:00 2001 From: Kevin Tan Date: Mon, 29 Jul 2024 02:41:21 -0700 Subject: [PATCH 1/2] refactor to improve efficiency --- .gitignore | 5 +- gaussian_renderer/__init__.py | 2 +- preprocess/auto_reorient.py | 4 +- render_hierarchy.py | 10 +-- requirements.txt | 6 +- scene/gaussian_model.py | 130 +++++++++++++++++--------------- scripts/full_train.py | 11 ++- submodules/gaussianhierarchy | 2 +- submodules/hierarchy-rasterizer | 2 +- submodules/simple-knn | 2 +- train_post.py | 10 +-- 11 files changed, 100 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index 6d63f0a..6ece7f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ -*.pyc -.vscode +.* +__* output build +*.egg-info tensorboard_3d screenshots diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index ecc9412..aad9024 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -220,7 +220,7 @@ def render_post( if pc.skybox_points == 0: skybox_inds = torch.Tensor([]).long() else: - skybox_inds = torch.range(pc._xyz.size(0) - pc.skybox_points, pc._xyz.size(0)-1, device="cuda").long() + skybox_inds = torch.arange(pc._xyz.size(0) - pc.skybox_points, pc._xyz.size(0), dtype=torch.long, device="cuda") means3D = torch.cat((means3D_base, means3D[skybox_inds])).contiguous() shs = torch.cat((shs_base, shs[skybox_inds])).contiguous() diff --git a/preprocess/auto_reorient.py b/preprocess/auto_reorient.py index 35d9ace..ef656bb 100644 --- a/preprocess/auto_reorient.py +++ b/preprocess/auto_reorient.py @@ -128,8 +128,8 @@ def rotate_camera(qvec, tvec, rot_matrix, upscale): right = candidates[i] - candidates[j] right /= np.linalg.norm(right) - up = torch.from_numpy(up).double() - right = torch.from_numpy(right).double() + up = torch.tensor(up, dtype=torch.float64) + right = torch.tensor(right, dtype=torch.float64) forward = torch.cross(up, right) forward /= torch.norm(forward, p=2) diff --git a/render_hierarchy.py b/render_hierarchy.py index 6030a65..6e47e29 100644 --- a/render_hierarchy.py +++ b/render_hierarchy.py @@ -33,11 +33,11 @@ def direct_collate(x): def render_set(args, scene, pipe, out_dir, tau, eval): render_path = out_dir - render_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() - parent_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() - nodes_for_render_indices = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() - interpolation_weights = torch.zeros(scene.gaussians._xyz.size(0)).float().cuda() - num_siblings = torch.zeros(scene.gaussians._xyz.size(0)).int().cuda() + render_indices = torch.zeros(scene.gaussians._xyz.size(0), dtype=torch.int, device="cuda") + parent_indices = torch.zeros(scene.gaussians._xyz.size(0), dtype=torch.int, device="cuda") + nodes_for_render_indices = torch.zeros(scene.gaussians._xyz.size(0), dtype=torch.int, device="cuda") + interpolation_weights = torch.zeros(scene.gaussians._xyz.size(0), dtype=torch.float, device="cuda") + num_siblings = torch.zeros(scene.gaussians._xyz.size(0), dtype=torch.int, device="cuda") psnr_test = 0.0 ssims = 0.0 diff --git a/requirements.txt b/requirements.txt index adc8f44..b0ffd48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,10 @@ tqdm joblib exif scikit-learn -timm==0.4.5 -opencv-python==4.9.0.80 +timm +opencv-python gradio_imageslider -gradio==4.29.0 +gradio matplotlib submodules/hierarchy-rasterizer submodules/simple-knn diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 0efd3e1..0798c98 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -155,8 +155,8 @@ def create_from_pcd( self.spatial_lr_scale = spatial_lr_scale - xyz = torch.tensor(np.asarray(pcd.points)).float().cuda() - fused_color = torch.tensor(np.asarray(pcd.colors)).float().cuda() + xyz = torch.tensor(np.asarray(pcd.points), dtype=torch.float, device="cuda") + fused_color = torch.tensor(np.asarray(pcd.colors), dtype=torch.float, device="cuda") minimum,_ = torch.min(xyz, axis=0) maximum,_ = torch.max(xyz, axis=0) @@ -170,20 +170,20 @@ def create_from_pcd( self.skybox_points = skybox_points radius = torch.linalg.norm(maximum - mean) - theta = (2.0 * torch.pi * torch.rand(skybox_points, device="cuda")).float() - phi = (torch.arccos(1.0 - 1.4 * torch.rand(skybox_points, device="cuda"))).float() - skybox_xyz = torch.zeros((skybox_points, 3)) + theta = (2.0 * torch.pi * torch.rand(skybox_points, dtype=torch.float, device="cuda")) + phi = (torch.arccos(1.0 - 1.4 * torch.rand(skybox_points, dtype=torch.float, device="cuda"))) + skybox_xyz = torch.zeros((skybox_points, 3), device="cuda") skybox_xyz[:, 0] = radius * 10 * torch.cos(theta)*torch.sin(phi) skybox_xyz[:, 1] = radius * 10 * torch.sin(theta)*torch.sin(phi) skybox_xyz[:, 2] = radius * 10 * torch.cos(phi) - skybox_xyz += mean.cpu() - xyz = torch.concat((skybox_xyz.cuda(), xyz)) - fused_color = torch.concat((torch.ones((skybox_points, 3)).cuda(), fused_color)) + skybox_xyz += mean + xyz = torch.concat((skybox_xyz, xyz)) + fused_color = torch.concat((torch.ones((skybox_points, 3), device="cuda"), fused_color)) fused_color[:skybox_points,0] *= 0.7 fused_color[:skybox_points,1] *= 0.8 fused_color[:skybox_points,2] *= 0.95 - features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2), dtype=torch.float, device="cuda") features[:, :3, 0 ] = RGB2SH(fused_color) features[:, 3:, 1:] = 0.0 @@ -206,13 +206,14 @@ def create_from_pcd( self.scaffold_points = None if scaffold_file != "": - scaffold_xyz, features_dc_scaffold, features_extra_scaffold, opacities_scaffold, scales_scaffold, rots_scaffold = self.load_ply_file(scaffold_file + "/point_cloud.ply", 1) - scaffold_xyz = torch.from_numpy(scaffold_xyz).float() - features_dc_scaffold = torch.from_numpy(features_dc_scaffold).permute(0, 2, 1).float() - features_extra_scaffold = torch.from_numpy(features_extra_scaffold).permute(0, 2, 1).float() - opacities_scaffold = torch.from_numpy(opacities_scaffold).float() - scales_scaffold = torch.from_numpy(scales_scaffold).float() - rots_scaffold = torch.from_numpy(rots_scaffold).float() + scaffold_xyz, features_dc_scaffold, features_extra_scaffold, opacities_scaffold, scales_scaffold, rots_scaffold = ( + self.load_ply_file(scaffold_file + "/point_cloud.ply", 1)) + scaffold_xyz = torch.tensor(scaffold_xyz, dtype=torch.float, device="cuda") + features_dc_scaffold = torch.tensor(features_dc_scaffold, dtype=torch.float, device="cuda").permute(0, 2, 1) + features_extra_scaffold = torch.tensor(features_extra_scaffold, dtype=torch.float, device="cuda").permute(0, 2, 1) + opacities_scaffold = torch.tensor(opacities_scaffold, dtype=torch.float, device="cuda") + scales_scaffold = torch.tensor(scales_scaffold, dtype=torch.float, device="cuda") + rots_scaffold = torch.tensor(rots_scaffold, dtype=torch.float, device="cuda") with open(scaffold_file + "/pc_info.txt") as f: skybox_points = int(f.readline()) @@ -225,10 +226,10 @@ def create_from_pcd( c = centerline.split(' ') e = extentline.split(' ') - center = torch.Tensor([float(c[0]), float(c[1]), float(c[2])]).cuda() - extent = torch.Tensor([float(e[0]), float(e[1]), float(e[2])]).cuda() + center = torch.tensor([float(c[0]), float(c[1]), float(c[2])], device="cuda") + extent = torch.tensor([float(e[0]), float(e[1]), float(e[2])], device="cuda") - distances1 = torch.abs(scaffold_xyz.cuda() - center) + distances1 = torch.abs(scaffold_xyz - center) selec = torch.logical_and( torch.max(distances1[:,0], distances1[:,1]) > 0.5 * extent[0], torch.max(distances1[:,0], distances1[:,1]) < 1.5 * extent[0]) @@ -236,15 +237,15 @@ def create_from_pcd( self.scaffold_points = selec.nonzero().size(0) - xyz = torch.concat((scaffold_xyz.cuda()[selec], xyz)) - features_dc = torch.concat((features_dc_scaffold.cuda()[selec,0:1,:], features_dc)) + xyz = torch.concat((scaffold_xyz[selec], xyz)) + features_dc = torch.concat((features_dc_scaffold[selec,0:1,:], features_dc)) - filler = torch.zeros((features_extra_scaffold.cuda()[selec,:,:].size(0), 15, 3)) - filler[:,0:3,:] = features_extra_scaffold.cuda()[selec,:,:] - features_rest = torch.concat((filler.cuda(), features_rest)) - scales = torch.concat((scales_scaffold.cuda()[selec], scales)) - rots = torch.concat((rots_scaffold.cuda()[selec], rots)) - opacities = torch.concat((opacities_scaffold.cuda()[selec], opacities)) + filler = torch.zeros((features_extra_scaffold[selec,:,:].size(0), 15, 3), device="cuda") + filler[:,0:3,:] = features_extra_scaffold[selec,:,:] + features_rest = torch.concat((filler, features_rest)) + scales = torch.concat((scales_scaffold[selec], scales)) + rots = torch.concat((rots_scaffold[selec], rots)) + opacities = torch.concat((opacities_scaffold[selec], opacities)) self._xyz = nn.Parameter(xyz.requires_grad_(True)) self._features_dc = nn.Parameter(features_dc.requires_grad_(True)) @@ -327,6 +328,13 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): self.spatial_lr_scale = spatial_lr_scale xyz, shs_all, alpha, scales, rots, nodes, boxes = load_hierarchy(path) + xyz = xyz.cuda() + shs_all = shs_all.cuda() + alpha = alpha.cuda() + scales = scales.cuda() + rots = rots.cuda() + nodes = nodes.cuda() + boxes = boxes.cuda() base = os.path.dirname(path) @@ -336,7 +344,7 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): int_val = int.from_bytes(bytes[:4], "little", signed="False") dt = np.dtype(np.int32) vals = np.frombuffer(bytes[4:], dtype=dt) - self.anchors = torch.from_numpy(vals).long().cuda() + self.anchors = torch.tensor(vals, dtype=torch.long, device="cuda") except: print("WARNING: NO ANCHORS FOUND") self.anchors = torch.Tensor([]).long() @@ -347,7 +355,9 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): with open(exposure_file, "r") as f: exposures = json.load(f) - self.pretrained_exposures = {image_name: torch.FloatTensor(exposures[image_name]).requires_grad_(False).cuda() for image_name in exposures} + self.pretrained_exposures = { + image_name: torch.tensor(exposures[image_name], dtype=torch.float, device="cuda", requires_grad=False) + for image_name in exposures} else: print(f"No exposure to be loaded at {exposure_file}") self.pretrained_exposures = None @@ -355,13 +365,14 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): #retrieve skybox self.skybox_points = 0 if scaffold_file != "": - scaffold_xyz, features_dc_scaffold, features_extra_scaffold, opacities_scaffold, scales_scaffold, rots_scaffold = self.load_ply_file(scaffold_file + "/point_cloud.ply", 1) - scaffold_xyz = torch.from_numpy(scaffold_xyz).float() - features_dc_scaffold = torch.from_numpy(features_dc_scaffold).permute(0, 2, 1).float() - features_extra_scaffold = torch.from_numpy(features_extra_scaffold).permute(0, 2, 1).float() - opacities_scaffold = torch.from_numpy(opacities_scaffold).float() - scales_scaffold = torch.from_numpy(scales_scaffold).float() - rots_scaffold = torch.from_numpy(rots_scaffold).float() + scaffold_xyz, features_dc_scaffold, features_extra_scaffold, opacities_scaffold, scales_scaffold, rots_scaffold = ( + self.load_ply_file(scaffold_file + "/point_cloud.ply", 1)) + scaffold_xyz = torch.tensor(scaffold_xyz, dtype=torch.float, device="cuda") + features_dc_scaffold = torch.tensor(features_dc_scaffold, dtype=torch.float, device="cuda").permute(0, 2, 1) + features_extra_scaffold = torch.tensor(features_extra_scaffold, dtype=torch.float, device="cuda").permute(0, 2, 1) + opacities_scaffold = torch.tensor(opacities_scaffold, dtype=torch.float, device="cuda") + scales_scaffold = torch.tensor(scales_scaffold, dtype=torch.float, device="cuda") + rots_scaffold = torch.tensor(rots_scaffold, dtype=torch.float, device="cuda") with open(scaffold_file + "/pc_info.txt") as f: skybox_points = int(f.readline()) @@ -370,24 +381,25 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): if self.skybox_points > 0: if scaffold_file != "": - skybox_xyz, features_dc_sky, features_rest_sky, opacities_sky, scales_sky, rots_sky = scaffold_xyz[:skybox_points], features_dc_scaffold[:skybox_points], features_extra_scaffold[:skybox_points], opacities_scaffold[:skybox_points], scales_scaffold[:skybox_points], rots_scaffold[:skybox_points] + skybox_xyz, features_dc_sky, features_rest_sky, opacities_sky, scales_sky, rots_sky = ( + scaffold_xyz[:skybox_points], features_dc_scaffold[:skybox_points], features_extra_scaffold[:skybox_points], opacities_scaffold[:skybox_points], scales_scaffold[:skybox_points], rots_scaffold[:skybox_points]) opacities_sky = torch.sigmoid(opacities_sky) xyz = torch.cat((xyz, skybox_xyz)) alpha = torch.cat((alpha, opacities_sky)) scales = torch.cat((scales, scales_sky)) rots = torch.cat((rots, rots_sky)) - filler = torch.zeros(features_dc_sky.size(0), 16, 3) + filler = torch.zeros(features_dc_sky.size(0), 16, 3, device="cuda") filler[:, :1, :] = features_dc_sky filler[:, 1:4, :] = features_rest_sky shs_all = torch.cat((shs_all, filler)) - self._xyz = nn.Parameter(xyz.cuda().requires_grad_(True)) - self._features_dc = nn.Parameter(shs_all.cuda()[:,:1,:].requires_grad_(True)) - self._features_rest = nn.Parameter(shs_all.cuda()[:,1:16,:].requires_grad_(True)) - self._opacity = nn.Parameter(alpha.cuda().requires_grad_(True)) - self._scaling = nn.Parameter(scales.cuda().requires_grad_(True)) - self._rotation = nn.Parameter(rots.cuda().requires_grad_(True)) + self._xyz = nn.Parameter(xyz.requires_grad_(True)) + self._features_dc = nn.Parameter(shs_all[:,:1,:].requires_grad_(True)) + self._features_rest = nn.Parameter(shs_all[:,1:16,:].requires_grad_(True)) + self._opacity = nn.Parameter(alpha.requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") self.opacity_activation = torch.abs @@ -395,25 +407,25 @@ def create_from_hier(self, path, spatial_lr_scale : float, scaffold_file : str): self.hierarchy_path = path - self.nodes = nodes.cuda() - self.boxes = boxes.cuda() + self.nodes = nodes + self.boxes = boxes def create_from_pt(self, path, spatial_lr_scale : float ): self.spatial_lr_scale = spatial_lr_scale - xyz = torch.load(path + "/done_xyz.pt") - shs_dc = torch.load(path + "/done_dc.pt") - shs_rest = torch.load(path + "/done_rest.pt") - alpha = torch.load(path + "/done_opacity.pt") - scales = torch.load(path + "/done_scaling.pt") - rots = torch.load(path + "/done_rotation.pt") - - self._xyz = nn.Parameter(xyz.cuda().requires_grad_(True)) - self._features_dc = nn.Parameter(shs_dc.cuda().requires_grad_(True)) - self._features_rest = nn.Parameter(shs_rest.cuda().requires_grad_(True)) - self._opacity = nn.Parameter(alpha.cuda().requires_grad_(True)) - self._scaling = nn.Parameter(scales.cuda().requires_grad_(True)) - self._rotation = nn.Parameter(rots.cuda().requires_grad_(True)) + xyz = torch.load(path + "/done_xyz.pt", map_location="cuda", mmap=True) + shs_dc = torch.load(path + "/done_dc.pt", map_location="cuda", mmap=True) + shs_rest = torch.load(path + "/done_rest.pt", map_location="cuda", mmap=True) + alpha = torch.load(path + "/done_opacity.pt", map_location="cuda", mmap=True) + scales = torch.load(path + "/done_scaling.pt", map_location="cuda", mmap=True) + rots = torch.load(path + "/done_rotation.pt", map_location="cuda", mmap=True) + + self._xyz = nn.Parameter(xyz.requires_grad_(True)) + self._features_dc = nn.Parameter(shs_dc.requires_grad_(True)) + self._features_rest = nn.Parameter(shs_rest.requires_grad_(True)) + self._opacity = nn.Parameter(alpha.requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") def save_hier(self): diff --git a/scripts/full_train.py b/scripts/full_train.py index 53392c3..85f6875 100644 --- a/scripts/full_train.py +++ b/scripts/full_train.py @@ -55,7 +55,7 @@ def setup_dirs(images, depths, masks, colmap, chunks, output, project): parser.add_argument('--output_dir', default="") parser.add_argument('--use_slurm', action="store_true", default=False) - parser.add_argument('--skip_if_exists', action="store_true", default=False, help="Skip training a chunk if it already has a hierarchy") + parser.add_argument('--skip_if_exists', action="store_true", default=True, help="Skip training a chunk if it already has a hierarchy") parser.add_argument('--keep_running', action="store_true", default=False, help="Keep running even if a chunk processing fails") args = parser.parse_args() print(args.extra_training_args) @@ -106,7 +106,7 @@ def setup_dirs(images, depths, masks, colmap, chunks, output, project): ]) if masks_dir != "": train_coarse_args += " --alpha_masks " + masks_dir - if args.extra_training_args != "": + if args.extra_training_args != "": train_coarse_args += " " + args.extra_training_args try: @@ -156,7 +156,7 @@ def setup_dirs(images, depths, masks, colmap, chunks, output, project): source_chunk = os.path.join(chunks_dir, chunk_name) trained_chunk = os.path.join(output_dir, "trained_chunks", chunk_name) - if args.skip_if_exists and os.path.exists(os.path.join(trained_chunk, "hierarchy.hier_opt")): + if args.skip_if_exists and os.path.exists(os.path.join(trained_chunk, "hierarchy.hier")): print(f"Skipping {chunk_name}") else: ## Training can be done in parallel using slurm. @@ -200,11 +200,14 @@ def setup_dirs(images, depths, masks, colmap, chunks, output, project): if not args.keep_running: sys.exit(1) + if args.skip_if_exists and os.path.exists(os.path.join(trained_chunk, "hierarchy.hier_opt")): + print(f"Skipping {chunk_name}") + else: # Post optimization on each chunks print(f"post optimizing chunk {chunk_name}") try: subprocess.run( - post_opt_chunk_args + " -s "+ source_chunk + + post_opt_chunk_args + " -s "+ source_chunk + " --model_path " + trained_chunk + " --hierarchy " + os.path.join(trained_chunk, "hierarchy.hier"), shell=True, check=True diff --git a/submodules/gaussianhierarchy b/submodules/gaussianhierarchy index 677c855..755d4dc 160000 --- a/submodules/gaussianhierarchy +++ b/submodules/gaussianhierarchy @@ -1 +1 @@ -Subproject commit 677c8553dc64dfd62c272eca94a291a277733113 +Subproject commit 755d4dc22ab3b396f0f27b4e2d065d2343b27769 diff --git a/submodules/hierarchy-rasterizer b/submodules/hierarchy-rasterizer index 75d5138..4e76dc1 160000 --- a/submodules/hierarchy-rasterizer +++ b/submodules/hierarchy-rasterizer @@ -1 +1 @@ -Subproject commit 75d513869f2d60ba205240e1a77012127e5ea142 +Subproject commit 4e76dc16b2d5074705e14362f5b0adb7ce8135bb diff --git a/submodules/simple-knn b/submodules/simple-knn index 86710c2..b38f377 160000 --- a/submodules/simple-knn +++ b/submodules/simple-knn @@ -1 +1 @@ -Subproject commit 86710c2d4b46680c02301765dd79e465819c8f19 +Subproject commit b38f377d092cfbe9c011acf1b88efbf49ed4ca19 diff --git a/train_post.py b/train_post.py index 0f6ac97..b3e6b5d 100644 --- a/train_post.py +++ b/train_post.py @@ -56,11 +56,11 @@ def training(dataset, opt, pipe, saving_iterations, checkpoint_iterations, check limit = 0.001 - render_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() - parent_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() - nodes_for_render_indices = torch.zeros(gaussians._xyz.size(0)).int().cuda() - interpolation_weights = torch.zeros(gaussians._xyz.size(0)).float().cuda() - num_siblings = torch.zeros(gaussians._xyz.size(0)).int().cuda() + render_indices = torch.zeros(gaussians._xyz.size(0), dtype=torch.int, device="cuda") + parent_indices = torch.zeros(gaussians._xyz.size(0), dtype=torch.int, device="cuda") + nodes_for_render_indices = torch.zeros(gaussians._xyz.size(0), dtype=torch.int, device="cuda") + interpolation_weights = torch.zeros(gaussians._xyz.size(0), dtype=torch.float, device="cuda") + num_siblings = torch.zeros(gaussians._xyz.size(0), dtype=torch.int, device="cuda") to_render = 0 limmax = 0.1 From 89e1b8f364483933afe9fb477095a5adca2d4811 Mon Sep 17 00:00:00 2001 From: Kevin Tan Date: Mon, 29 Jul 2024 04:25:05 -0700 Subject: [PATCH 2/2] fix bug --- scripts/full_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/full_train.py b/scripts/full_train.py index 85f6875..0623b8f 100644 --- a/scripts/full_train.py +++ b/scripts/full_train.py @@ -81,7 +81,7 @@ def setup_dirs(images, depths, masks, colmap, chunks, output, project): print("Skipping coarse") else: if args.use_slurm: - if args.args.extra_training_args != "": + if args.extra_training_args != "": print("\nThe script does not support passing extra_training_args to slurm!!\n") submitted_jobs_ids = [] slurm_args = [