-
Notifications
You must be signed in to change notification settings - Fork 91
/
fps_benchmark_demo.py
89 lines (76 loc) · 3.48 KB
/
fps_benchmark_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from pathlib import Path
from tqdm import tqdm
import numpy as np
from dataclasses import dataclass
import matplotlib.pyplot as plt
from gaussian_renderer import render
from utils.general_utils import safe_state
from utils.viewer_utils import OrbitCamera
from argparse import ArgumentParser
from gaussian_renderer import GaussianModel, FlameGaussianModel
@dataclass
class PipelineConfig:
    """Switches handed to ``render`` as its pipeline argument.

    All three flags default to ``False``. Their exact effect is decided inside
    ``gaussian_renderer.render``, which is not visible here.
    """

    # Debug flag forwarded to the renderer.
    debug: bool = False
    # NOTE(review): by name, selects a Python (non-CUDA) path for 3D covariance
    # computation — confirm in gaussian_renderer.
    compute_cov3D_python: bool = False
    # NOTE(review): by name, selects a Python path for SH-to-colour conversion —
    # confirm in gaussian_renderer.
    convert_SHs_python: bool = False
def prepare_camera(width, height):
    """Create a fixed orbit camera packaged in the attribute layout ``render`` reads.

    Args:
        width: Image width in pixels, passed to ``OrbitCamera``.
        height: Image height in pixels, passed to ``OrbitCamera``.

    Returns:
        The class ``Cam`` itself (not an instance). Its class attributes hold
        the horizontal/vertical field of view in radians, the image size, and
        the world-view / full-projection matrices and camera centre as float32
        CUDA tensors.
    """
    cam = OrbitCamera(width, height, r=1, fovy=20, convention='opencv')

    # Plain attribute container. A ``@dataclass`` decorator here would declare
    # no fields (none of these attributes are annotated), so it is omitted.
    class Cam:
        # OrbitCamera's fovx/fovy are treated as degrees and converted to radians.
        FoVx = float(np.radians(cam.fovx))
        FoVy = float(np.radians(cam.fovy))
        image_height = cam.image_height
        image_width = cam.image_width
        # The transpose is required by the gaussian splatting rasterizer.
        world_view_transform = torch.tensor(cam.world_view_transform).float().cuda().T
        full_proj_transform = torch.tensor(cam.full_proj_transform).float().cuda().T
        # Cast to float32 like the matrices above so the rasterizer never
        # receives a float64 camera centre if ``cam.pose`` is float64.
        camera_center = torch.tensor(cam.pose[:3, 3]).float().cuda()

    return Cam
def render_sets(pipeline : PipelineConfig, point_path, sh_degree, height, width, n_iter, vis=False, n_rounds=3):
    """Load a Gaussian point cloud and benchmark its rendering throughput.

    Renders the same fixed camera view ``n_iter`` times per round for
    ``n_rounds`` rounds. Each round is timed with CUDA events, and the elapsed
    time and frames per second are printed.

    Args:
        pipeline: Pipeline switches forwarded to ``render``.
        point_path: Path (``str`` or ``Path``) to the ``.ply`` point cloud.
        sh_degree: Spherical-harmonics degree passed to the model constructor.
        height: Render height in pixels.
        width: Render width in pixels.
        n_iter: Number of renders per timed round; must be >= 1.
        vis: If True, show the last rendered image with matplotlib.
        n_rounds: Number of timed rounds (default 3); must be >= 1.

    Raises:
        ValueError: If ``point_path`` is None, or ``n_iter`` / ``n_rounds`` < 1.
        FileNotFoundError: If ``point_path`` does not exist.
    """
    # Validate before touching the filesystem or the GPU. Explicit raises
    # (not ``assert``) survive ``python -O``.
    if point_path is None:
        raise ValueError("point_path must be provided.")
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}.")
    if n_rounds < 1:
        raise ValueError(f"n_rounds must be >= 1, got {n_rounds}.")

    # Normalize once so both str and Path inputs work everywhere below.
    point_path = Path(point_path)
    if not point_path.exists():
        raise FileNotFoundError(f'{point_path} does not exist.')

    with torch.no_grad():
        # A flame_param.npz next to the point cloud selects the FLAME-bound model.
        if (point_path.parent / "flame_param.npz").exists():
            gaussians = FlameGaussianModel(sh_degree)
        else:
            gaussians = GaussianModel(sh_degree)
        gaussians.load_ply(point_path, has_target=False)

        # White background.
        background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
        cam = prepare_camera(width, height)

        rendering = None
        for i in range(n_rounds):
            print(f"\nRound {i+1}")
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in tqdm(range(n_iter)):
                # Kept inside the timed loop (as originally written), so the
                # mesh-selection cost is included in the measured FPS.
                if gaussians.binding is not None:
                    gaussians.select_mesh_by_timestep(0)
                rendering = render(cam, gaussians, pipeline, background)["render"]
            end.record()
            # CUDA work is asynchronous; wait for it before reading the events.
            torch.cuda.synchronize()
            elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
            print(f"Rendering {n_iter} images took {elapsed_time:.2f} s")
            print(f"FPS: {n_iter / elapsed_time:.2f}")

        if vis:
            print("\nVisualizing the rendering result")
            # Channel-first tensor -> HWC array clipped to [0, 1] for imshow.
            plt.imshow(rendering.permute(1, 2, 0).clip(0, 1).cpu().numpy())
            plt.show()
if __name__ == "__main__":
    # Command-line interface. Registration order is preserved, so --help
    # lists options in the same order as before.
    cli = ArgumentParser(description="Testing script parameters")
    typed_options = (
        ("--point_path", Path, "media/306/point_cloud.ply"),
        ("--sh_degree", int, 3),
        ("--height", int, 802),
        ("--width", int, 550),
        ("--n_iter", int, 500),
    )
    for flag, converter, fallback in typed_options:
        cli.add_argument(flag, default=fallback, type=converter)
    for switch in ("--vis", "--quiet"):
        cli.add_argument(switch, action="store_true")
    opts = cli.parse_args()

    print("Rendering " + str(opts.point_path))

    # Initialize system state (RNG) via the project helper; --quiet is forwarded.
    safe_state(opts.quiet)

    render_sets(
        PipelineConfig(),
        opts.point_path,
        opts.sh_degree,
        opts.height,
        opts.width,
        opts.n_iter,
        opts.vis,
    )