diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh
index 274185842..3eb2e266f 100644
--- a/examples/benchmarks/compression/mcmc.sh
+++ b/examples/benchmarks/compression/mcmc.sh
@@ -1,4 +1,4 @@
-SCENE_DIR="data/360_v2"
+SCENE_DIR="../data/360_v2"
 # eval all 9 scenes for benchmarking
 SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers"
 
@@ -21,22 +21,16 @@ CAP_MAX=1000000
 
 for SCENE in $SCENE_LIST;
 do
-    if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then
-        DATA_FACTOR=2
-    else
-        DATA_FACTOR=4
-    fi
-
     echo "Running $SCENE"
 
     # train without eval
-    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor -1 \
         --strategy.cap-max $CAP_MAX \
         --data_dir $SCENE_DIR/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE/
 
     # eval: use vgg for lpips to align with other benchmarks
-    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor -1 \
         --strategy.cap-max $CAP_MAX \
         --data_dir $SCENE_DIR/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE/ \
@@ -52,4 +46,4 @@ then
     python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
 else
     echo "zip command not found, skipping zipping"
-fi
\ No newline at end of file
+fi
diff --git a/examples/benchmarks/compression/mcmc_db.sh b/examples/benchmarks/compression/mcmc_db.sh
new file mode 100644
index 000000000..5800ec7ee
--- /dev/null
+++ b/examples/benchmarks/compression/mcmc_db.sh
@@ -0,0 +1,49 @@
+SCENE_DIR="data/db"
+# eval all 2 scenes for benchmarking
+SCENE_LIST="playroom drjohnson"
+
+# # 0.36M GSs
+# RESULT_DIR="results/benchmark_db_mcmc_0_36M_png_compression"
+# CAP_MAX=360000
+
+# # 0.49M GSs
+# RESULT_DIR="results/benchmark_db_mcmc_0_49M_png_compression"
+# CAP_MAX=490000
+
+# 1M GSs
+RESULT_DIR="results/benchmark_db_mcmc_1M_png_compression"
+CAP_MAX=1000000
+
+# # 4M GSs
+# RESULT_DIR="results/benchmark_db_mcmc_4M_png_compression"
+# CAP_MAX=4000000
+
+for SCENE in $SCENE_LIST;
+do
+    echo "Running $SCENE"
+
+    # train without eval
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor -1 \
+        --strategy.cap-max $CAP_MAX \
+        --opacity_reg 0.001 \
+        --data_dir $SCENE_DIR/$SCENE/ \
+        --result_dir $RESULT_DIR/$SCENE/
+
+    # eval: use vgg for lpips to align with other benchmarks
+    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor -1 \
+        --strategy.cap-max $CAP_MAX \
+        --data_dir $SCENE_DIR/$SCENE/ \
+        --result_dir $RESULT_DIR/$SCENE/ \
+        --lpips_net vgg \
+        --compression png \
+        --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt
+done
+
+# Zip the compressed files and summarize the stats
+if command -v zip &> /dev/null
+then
+    echo "Zipping results"
+    python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
+else
+    echo "zip command not found, skipping zipping"
+fi
diff --git a/examples/benchmarks/compression/mcmc_syn.sh b/examples/benchmarks/compression/mcmc_syn.sh
new file mode 100644
index 000000000..410df7c3d
--- /dev/null
+++ b/examples/benchmarks/compression/mcmc_syn.sh
@@ -0,0 +1,47 @@
+SCENE_DIR="data/nerf_synthetic"
+SCENE_LIST="chair drums ficus hotdog lego materials mic ship"
+
+# # 0.36M
GSs +# RESULT_DIR="results/benchmark_syn_mcmc_0_36M_png_compression" +# CAP_MAX=360000 + +# # 0.49M GSs +# RESULT_DIR="results/benchmark_syn_mcmc_0_49M_png_compression" +# CAP_MAX=490000 + +# 1M GSs +RESULT_DIR="results/benchmark_syn_mcmc_1M_png_compression" +CAP_MAX=1000000 + +# # 4M GSs +# RESULT_DIR="results/benchmark_syn_mcmc_4M_png_compression" +# CAP_MAX=4000000 + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train without eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \ + --strategy.cap-max $CAP_MAX \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + + # eval: use vgg for lpips to align with other benchmarks + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \ + --strategy.cap-max $CAP_MAX \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --lpips_net vgg \ + --compression png \ + --ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt +done + +# Zip the compressed files and summarize the stats +if command -v zip &> /dev/null +then + echo "Zipping results" + python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST +else + echo "zip command not found, skipping zipping" +fi diff --git a/examples/benchmarks/compression/mcmc_tt.sh b/examples/benchmarks/compression/mcmc_tt.sh index 34d9f5ea7..754ac3910 100644 --- a/examples/benchmarks/compression/mcmc_tt.sh +++ b/examples/benchmarks/compression/mcmc_tt.sh @@ -23,13 +23,13 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor -1 \ --strategy.cap-max $CAP_MAX \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # eval: use vgg for lpips to align with other benchmarks - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor -1 \ --strategy.cap-max $CAP_MAX \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ \ @@ -45,4 +45,4 @@ then python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST else echo "zip command not found, skipping zipping" -fi \ No newline at end of file +fi diff --git a/examples/benchmarks/compression/results/DeepBlending.csv b/examples/benchmarks/compression/results/DeepBlending.csv new file mode 100644 index 000000000..0d2b85bb3 --- /dev/null +++ b/examples/benchmarks/compression/results/DeepBlending.csv @@ -0,0 +1,5 @@ +Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians +,29.417935371398926,0.8983585238456726,0.2698713690042496,6877574.0,360000 +,29.31465244293213,0.8991637229919434,0.2642294019460678,8770896.0,490000 +-1.00M,29.8878116607666,0.9035206139087677,0.25052621215581894,16123593.0,1000000 +,29.644909858703613,0.903471440076828,0.23753593116998672,58272202.5,4000000 diff --git a/examples/benchmarks/compression/results/MipNeRF360.csv b/examples/benchmarks/compression/results/MipNeRF360.csv index faf69be9f..816a877ca 100644 --- a/examples/benchmarks/compression/results/MipNeRF360.csv +++ b/examples/benchmarks/compression/results/MipNeRF360.csv @@ -1,5 +1,5 @@ Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians -,26.64,0.788,0.270,6916294,360000 -,26.88,0.796,0.256,8796870,490000 --1.00M,27.29,0.811,0.229,16038022,1000000 
-,27.70,0.825,0.197,57812682,4000000 \ No newline at end of file +,26.694747077094185,0.786249836285909,0.27215222186512417,6912028.777777778,360000 +,26.92844581604004,0.7948330309655931,0.25781087908479905,8767350.777777778,490000 +-1.00M,27.329831653171116,0.8091025948524475,0.23095709085464478,16028623.0,1000000 +,27.795140160454643,0.8232156700558133,0.198530751797888,57659358.222222224,4000000 diff --git a/examples/benchmarks/compression/results/SyntheticNeRF.csv b/examples/benchmarks/compression/results/SyntheticNeRF.csv new file mode 100644 index 000000000..2391d390b --- /dev/null +++ b/examples/benchmarks/compression/results/SyntheticNeRF.csv @@ -0,0 +1,5 @@ +Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians +,33.50586247444153,0.9701051115989685,0.030247553484514356,7316613.5,360000 +,33.5969717502594,0.9701937362551689,0.029584018629975617,9346813.75,490000 +-1.00M,33.724597454071045,0.9700975641608238,0.029100348940119147,17143747.375,1000000 +,33.32782459259033,0.968658909201622,0.029677038837689906,61105994.625,4000000 diff --git a/examples/benchmarks/compression/results/TanksAndTemples.csv b/examples/benchmarks/compression/results/TanksAndTemples.csv index 5845808d9..9b7477669 100644 --- a/examples/benchmarks/compression/results/TanksAndTemples.csv +++ b/examples/benchmarks/compression/results/TanksAndTemples.csv @@ -1,5 +1,5 @@ Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians -,23.54,0.838,0.200,6875669,360000 -,23.62,0.845,0.188,8728572,490000 --1.00M,24.03,0.857,0.163,16100628,1000000 -,24.47,0.872,0.132,58239022,4000000 +,23.484140396118164,0.8359003365039825,0.20022188872098923,6814856.5,360000 +,23.68420124053955,0.8424293696880341,0.18749213218688965,8710374.5,490000 +-1.00M,23.996936798095703,0.855468362569809,0.16304801404476166,16065561.5,1000000 +,24.45703887939453,0.8690102994441986,0.13164417818188667,58291533.5,4000000 diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 11ad2a4b2..2e8e061dc 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -28,7 +28,7 @@ def _get_rel_paths(path_dir: str) -> List[str]: return paths -def _resize_image_folder(image_dir: str, resized_dir: str, factor: int) -> str: +def _resize_image_folder(image_dir: str, resized_dir: str, factor: float) -> str: """Resize image folder.""" print(f"Downscaling images by {factor}x from {image_dir} to {resized_dir}.") os.makedirs(resized_dir, exist_ok=True) @@ -59,12 +59,11 @@ class Parser: def __init__( self, data_dir: str, - factor: int = 1, + factor: int = -1, normalize: bool = False, test_every: int = 8, ): self.data_dir = data_dir - self.factor = factor self.normalize = normalize self.test_every = test_every @@ -104,7 +103,6 @@ def __init__( cam = manager.cameras[camera_id] fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) - K[:2, :] /= factor Ks_dict[camera_id] = K # Get distortion parameters. @@ -132,7 +130,10 @@ def __init__( ), f"Only perspective and fisheye cameras are supported, got {type_}" params_dict[camera_id] = params - imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) + imsize_dict[camera_id] = ( + cam.width // abs(factor), + cam.height // abs(factor), + ) mask_dict[camera_id] = None print( f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." 
        )
@@ -195,9 +196,36 @@ def __init__( colmap_image_dir, image_dir + "_png", factor=factor ) image_files = sorted(_get_rel_paths(image_dir)) + colmap_to_image = dict(zip(colmap_files, image_files)) image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] + # load one image to check the size. + actual_image = imageio.imread(image_paths[0])[..., :3] + actual_height, actual_width = actual_image.shape[:2] + + # need to check image resolution, side length > 1600 should be downscaled + # based on https://github.com/graphdeco-inria/gaussian-splatting/blob/54c035f7834b564019656c3e3fcc3646292f727d/utils/camera_utils.py#L50 + max_side = max(actual_width, actual_height) + global_down = max_side / 1600.0 + + if factor == -1 and max_side > 1600: + print( + "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " + "If this is not desired, please explicitly specify '--data_factor' as 1" + ) + factor = global_down + image_dir = _resize_image_folder( + colmap_image_dir, image_dir + "_1600px", factor=factor + ) + image_files = sorted(_get_rel_paths(image_dir)) + colmap_to_image = dict(zip(colmap_files, image_files)) + image_paths = [ + os.path.join(image_dir, colmap_to_image[f]) for f in image_names + ] + + self.factor = factor + # 3D points and {image_name -> [point_idx]} points = manager.points3D.astype(np.float32) points_err = manager.point3D_errors.astype(np.float32) @@ -242,18 +270,18 @@ def __init__( self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] self.transform = transform # np.ndarray, (4, 4) - # load one image to check the size. In the case of tanksandtemples dataset, the - # intrinsics stored in COLMAP corresponds to 2x upsampled images. - actual_image = imageio.imread(self.image_paths[0])[..., :3] - actual_height, actual_width = actual_image.shape[:2] colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] s_height, s_width = actual_height / colmap_height, actual_width / colmap_width for camera_id, K in self.Ks_dict.items(): + K[:2, :] /= factor K[0, :] *= s_width K[1, :] *= s_height self.Ks_dict[camera_id] = K width, height = self.imsize_dict[camera_id] - self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) + self.imsize_dict[camera_id] = ( + int(width * s_width / global_down), + int(height * s_height / global_down), + ) # undistortion self.mapx_dict = dict() diff --git a/examples/datasets/synthetic.py b/examples/datasets/synthetic.py new file mode 100644 index 000000000..c0faf3de6 --- /dev/null +++ b/examples/datasets/synthetic.py @@ -0,0 +1,282 @@ +import os +import json +from pathlib import Path + +import imageio.v2 as imageio +from typing import Any, Dict, Optional +from PIL import Image +import numpy as np +import torch + +from .normalize import ( + similarity_from_cameras, + transform_cameras, +) + + +def fov2focal(fov, pixels): + return pixels / (2 * np.tan(fov / 2)) + + +def load_synthetic(data_dir, file, factor: int, id_offset): + # Built from INRIA's Gaussian Splatting Code + # https://github.com/graphdeco-inria/gaussian-splatting/blob/8a70a8cd6f0d9c0a14f564844ead2d1147d5a7ac/scene/dataset_readers.py#L179 + camtoworlds = [] + camera_ids = [] + image_names = [] + image_paths = [] + Ks_dict = dict() + params_dict = dict() + imsize_dict = dict() # width, height + mask_dict = dict() + with open(os.path.join(data_dir, file)) as json_file: + contents = json.load(json_file) + FOVX = contents["camera_angle_x"] + + frames = contents["frames"] + + for idx, frame in 
enumerate(frames):
+        image_path = os.path.join(data_dir, frame["file_path"] + ".png")
+        camera_id = idx + id_offset
+
+        # NeRF 'transform_matrix' is a camera-to-world transform
+        c2w = np.array(frame["transform_matrix"])
+        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
+        c2w[:3, 1:3] *= -1
+
+        image_name = Path(image_path).stem
+        image = Image.open(image_path)
+
+        fx = fov2focal(FOVX, image.width)
+        fy = fov2focal(FOVX, image.height)
+        cx, cy = 0.5 * image.width, 0.5 * image.height
+        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=float)
+        K[:2, :] /= factor
+        Ks_dict[camera_id] = K
+
+        camera_ids.append(camera_id)
+        camtoworlds.append(c2w)
+        # assume no distortion
+        params_dict[camera_id] = np.empty(0, dtype=np.float32)
+        imsize_dict[camera_id] = (image.width // factor, image.height // factor)
+        mask_dict[camera_id] = None
+        image_names.append(image_name)
+        image_paths.append(image_path)
+    return (
+        camera_ids,
+        camtoworlds,
+        params_dict,
+        imsize_dict,
+        mask_dict,
+        image_names,
+        image_paths,
+        Ks_dict,
+    )
+
+
+class Parser:
+    """synthetic parser."""
+
+    def __init__(
+        self,
+        data_dir: str,
+        factor: int = 1,
+        normalize: bool = False,
+        test_every: int = 8,
+        test_max_res: int = 1600,  # max image side length in pixel for test split
+    ):
+        self.data_dir = data_dir
+        self.factor = factor
+        self.normalize = normalize
+        # test_every is not needed, as we have a dedicated test set
+        self.test_every = test_every
+
+        # Load camera-to-world matrices.
+        camtoworlds = []
+        camera_ids = []
+        image_names = []
+        image_paths = []
+        Ks_dict = dict()
+        params_dict = dict()
+        imsize_dict = dict()  # width, height
+        mask_dict = dict()
+        bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
+
+        # load training data
+        (
+            camera_ids,
+            camtoworlds,
+            params_dict,
+            imsize_dict,
+            mask_dict,
+            image_names,
+            image_paths,
+            Ks_dict,
+        ) = load_synthetic(data_dir, "transforms_train.json", factor, 0)
+
+        self.train_indices = np.arange(len(image_names))
+        train_camera_id_len = len(camera_ids)
+
+        # load test data
+        (
+            camera_ids_testdata,
+            camtoworlds_testdata,
+            params_dict_testdata,
+            imsize_dict_testdata,
+            mask_dict_testdata,
+            image_names_testdata,
+            image_paths_testdata,
+            Ks_dict_testdata,
+        ) = load_synthetic(
+            data_dir, "transforms_test.json", factor, train_camera_id_len
+        )
+
+        # join data
+        camera_ids += camera_ids_testdata
+        camtoworlds += camtoworlds_testdata
+        params_dict.update(params_dict_testdata)
+        imsize_dict.update(imsize_dict_testdata)
+        mask_dict.update(mask_dict_testdata)
+        image_names += image_names_testdata
+        image_paths += image_paths_testdata
+        Ks_dict.update(Ks_dict_testdata)
+
+        self.test_indices = np.arange(len(self.train_indices), len(image_names))
+
+        camtoworlds = np.array(camtoworlds)
+
+        print(
+            f"[Parser] {len(image_names)} images, taken by {len(set(camera_ids))} cameras."
+        )
+
+        if len(image_names) == 0:
+            raise ValueError("No images found.")
+
+        # Load extended metadata. Used by Bilarf dataset.
+        self.extconf = {
+            "spiral_radius_scale": 1.0,
+            "no_factor_suffix": False,
+        }
+        extconf_file = os.path.join(data_dir, "ext_metadata.json")
+        if os.path.exists(extconf_file):
+            with open(extconf_file) as f:
+                self.extconf.update(json.load(f))
+
+        # Load bounds if possible (only used in forward facing scenes).
+ self.bounds = np.array([0.01, 1.0]) + posefile = os.path.join(data_dir, "poses_bounds.npy") + if os.path.exists(posefile): + self.bounds = np.load(posefile)[:, -2:] + + # 3D points + points = None + points_err = None + points_rgb = None + point_indices = dict() + + # Normalize the world space. + if normalize: + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + transform = T1 + else: + transform = np.eye(4) + + self.image_names = image_names # List[str], (num_images,) + self.image_paths = image_paths # List[str], (num_images,) + self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.camera_ids = camera_ids # List[int], (num_images,) + self.Ks_dict = Ks_dict # Dict of camera_id -> K + self.params_dict = params_dict # Dict of camera_id -> params + self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) + self.mask_dict = mask_dict # Dict of camera_id -> mask + self.points = points # np.ndarray, (num_points, 3) + self.points_err = points_err # np.ndarray, (num_points,) + self.points_rgb = points_rgb # np.ndarray, (num_points, 3) + self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] + self.transform = transform # np.ndarray, (4, 4) + + # load one image to check the size. In the case of tanksandtemples dataset, the + # intrinsics stored in COLMAP corresponds to 2x upsampled images. + actual_image = imageio.imread(self.image_paths[0])[..., :3] + actual_height, actual_width = actual_image.shape[:2] + + colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] + s_height, s_width = actual_height / colmap_height, actual_width / colmap_width + for camera_id, K in self.Ks_dict.items(): + K[0, :] *= s_width + K[1, :] *= s_height + self.Ks_dict[camera_id] = K + width, height = self.imsize_dict[camera_id] + self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) + + # size of the scene measured by cameras + camera_locations = camtoworlds[:, :3, 3] + scene_center = np.mean(camera_locations, axis=0) + dists = np.linalg.norm(camera_locations - scene_center, axis=1) + self.scene_scale = np.max(dists) + + +class Dataset: + """A simple dataset class.""" + + def __init__( + self, + parser: Parser, + split: str = "train", + patch_size: Optional[int] = None, + load_depths: bool = False, + white_background: bool = True, + ): + self.parser = parser + self.split = split + self.patch_size = patch_size + self.load_depths = load_depths + self.white_background = white_background + if split == "train": + self.indices = self.parser.train_indices + else: + self.indices = self.parser.test_indices + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item: int) -> Dict[str, Any]: + index = self.indices[item] + camera_id = self.parser.camera_ids[index] + image = Image.open(self.parser.image_paths[index]) + camtoworlds = self.parser.camtoworlds[index] + mask = self.parser.mask_dict[camera_id] + K = self.parser.Ks_dict[camera_id].copy() + + image = np.array(image) + bg = ( + np.array([255.0, 255.0, 255.0]) + if self.white_background + else np.array([0.0, 0.0, 0.0]) + ) + + image = image[:, :, :3] * (image[:, :, 3:4] / 255.0) + bg * ( + 1 - (image[:, :, 3:4] / 255.0) + ) + image = image[..., :3] + + if self.patch_size is not None: + # Random crop. 
+ h, w = image.shape[:2] + x = np.random.randint(0, max(w - self.patch_size, 1)) + y = np.random.randint(0, max(h - self.patch_size, 1)) + image = image[y : y + self.patch_size, x : x + self.patch_size] + K[0, 2] -= x + K[1, 2] -= y + + data = { + "K": torch.from_numpy(K).float(), + "camtoworld": torch.from_numpy(camtoworlds).float(), + "image": torch.from_numpy(image).float(), + "image_id": item, # the index of the image in the dataset + } + if mask is not None: + data["mask"] = torch.from_numpy(mask).bool() + + return data diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ca9271e81..5a866deba 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -58,7 +58,7 @@ class Config: # Path to the Mip-NeRF 360 dataset data_dir: str = "data/360_v2/garden" # Downsample factor for the dataset - data_factor: int = 4 + data_factor: int = -1 # Directory to save results result_dir: str = "results/garden" # Every N images there is a test image diff --git a/examples/simple_trainer_synthetic.py b/examples/simple_trainer_synthetic.py new file mode 100644 index 000000000..13ac488c3 --- /dev/null +++ b/examples/simple_trainer_synthetic.py @@ -0,0 +1,1108 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import yaml +from datasets.synthetic import Dataset +from datasets.synthetic import Parser +from datasets.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal, assert_never +from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from lib_bilagrid import ( + BilateralGrid, + slice, + color_correct, + total_variation_loss, +) + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.optimizers import SelectiveAdam + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. 
+ ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 1 + # Directory to save results + result_dir: str = "results/garden" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + # Camera model + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "random" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use visible adam from Taming 3DGS. (experimental) + visible_adam: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.0 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. 
(experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + assert_never(strategy) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + visible_adam: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", + world_rank: int = 0, + world_size: int = 1, +) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + rgbs = rgbs[world_rank::world_size] + scales = scales[world_rank::world_size] + + N = points.shape[0] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + ] + + if feature_dim is None: + # color is SH coefficients. 
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + BS = batch_size * world_size + optimizer_class = None + if sparse_grad: + optimizer_class = torch.optim.SparseAdam + elif visible_adam: + optimizer_class = SelectiveAdam + else: + optimizer_class = torch.optim.Adam + optimizers = { + name: optimizer_class( + [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # use white background + self.use_white_background = True + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + white_background=self.use_white_background, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + visible_adam=cfg.visible_adam, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + world_rank=world_rank, + world_size=world_size, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) + + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + if self.use_white_background: + background = torch.ones(1, colors.shape[-1], device=self.device) + # background = torch.ones( + # camtoworlds.shape[0], colors.shape[-1], device="cuda" + # ) + else: + background = None + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + backgrounds=background, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + **kwargs, + ) + if masks is not None: + render_colors[~masks] = 0 + + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. 
+ schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + masks=masks, + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * 
self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + # regularizations + if cfg.opacity_reg > 0.0: + loss = ( + loss + + cfg.opacity_reg + * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + ) + if cfg.scale_reg > 0.0: + loss = ( + loss + + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() + ) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + # write images (gt and render) + # if world_rank == 0 and step % 800 == 0: + # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + # canvas = canvas.reshape(-1, *canvas.shape[2:]) + # imageio.imwrite( + # f"{self.render_dir}/train_rank{self.world_rank}.png", + # (canvas * 255).astype(np.uint8), + # ) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] 
+ is_coalesced=len(Ks) == 1, + ) + + if cfg.visible_adam: + gaussian_cnt = self.splats.means.shape[0] + if cfg.packed: + visibility_mask = torch.zeros_like( + self.splats["opacities"], dtype=bool + ) + visibility_mask.scatter_(0, info["gaussian_ids"], 1) + else: + visibility_mask = (info["radii"] > 0).any(0) + + # optimize + for optimizer in self.optimizers.values(): + if cfg.visible_adam: + optimizer.step(visibility_mask) + else: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # Run post-backward steps after backward and optimizer + if isinstance(self.cfg.strategy, DefaultStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + lr=schedulers[0].get_last_lr()[0], + ) + else: + assert_never(self.cfg.strategy) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_traj(step) + + # run compression + if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: + self.run_compression(step=step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. 
+ self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + masks = data["mask"].to(device) if "mask" in data else None + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + masks=masks, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = 
torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, render_alphas, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + step = ckpts[0]["step"] + runner.eval(step=step) + runner.render_traj(step=step) + if cfg.compression is not None: + runner.run_compression(step=step) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. 
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). + configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + Config( + init_opa=0.5, + init_scale=0.1, + opacity_reg=0.01, + scale_reg=0.01, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True)
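Note on the `--data_factor -1` default used by the benchmark scripts above: with this patch, `examples/datasets/colmap.py` keeps the native resolution unless the longest image side exceeds 1600 px, in which case it rescales by `max_side / 1600` (the heuristic borrowed from INRIA's gaussian-splatting `camera_utils.py`). The snippet below is a minimal standalone sketch of that rule, useful for sanity-checking which factor a run will effectively use; the helper name `effective_factor` is illustrative and not part of the codebase.

```python
def effective_factor(
    width: int, height: int, data_factor: float = -1.0, max_res: int = 1600
) -> float:
    """Downscale factor the parser would end up using (sketch of the rule in the diff)."""
    if data_factor != -1:
        # an explicit --data_factor is applied as-is
        return float(data_factor)
    max_side = max(width, height)
    if max_side > max_res:
        # auto mode: rescale so the longest side lands at ~1600 px
        return max_side / float(max_res)
    return 1.0


if __name__ == "__main__":
    print(effective_factor(5187, 3361))                  # large capture -> ~3.24
    print(effective_factor(800, 800))                    # 800x800 input stays at 1.0
    print(effective_factor(1920, 1080, data_factor=4))   # explicit factor wins -> 4.0
```

The Synthetic NeRF script passes `--data_factor 1` explicitly, so the auto-rescaling path is only exercised by the Mip-NeRF 360, Tanks and Temples, and Deep Blending runs.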