
Commit fed1278

Suit 3D GS to DTU Dataset

1 parent d44473b

File tree

6 files changed: +153 -39 lines

arguments/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def __init__(self, parser):
 
 class OptimizationParams(ParamGroup):
     def __init__(self, parser):
-        self.iterations = 10_000
+        self.iterations = 30_000
         self.position_lr_init = 0.00016
         self.position_lr_final = 0.0000016
         self.position_lr_delay_mult = 0.01
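Editor's note: the new 30 000-iteration default is only a default; these fields are exposed as command-line flags through the repository's ParamGroup helper. A minimal sketch of how the value reaches the training loop — the `OptimizationParams`/`extract` flow follows the upstream 3D Gaussian Splatting code, so treat the exact names and import path as assumptions:

```python
# Hedged sketch, not verbatim repository code.
from argparse import ArgumentParser
from arguments import OptimizationParams   # assumed import path in this repository

parser = ArgumentParser()
op = OptimizationParams(parser)            # registers --iterations (default now 30_000)
args = parser.parse_args([])               # or e.g. ["--iterations", "10000"] to shorten a run
opt = op.extract(args)
print(opt.iterations)                      # value consumed by the training loop in train.py
```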

scene/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,10 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration
         elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
             print("Found transforms_train.json file, assuming Blender data set!")
             scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
+        elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")):
+            print("Found cameras_sphere.npz file, assuming DTU data set!")
+            scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz")
+
         else:
             assert False, "Could not recognize scene type!"
 
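Editor's note: with this change the loader auto-detects three dataset layouts by probing the source path for marker files — a COLMAP `sparse` folder, a Blender `transforms_train.json`, or a NeuS/IDR-style `cameras_sphere.npz` for DTU. A condensed sketch of that dispatch (my own summary, not the repository's exact code):

```python
import os

# sceneLoadTypeCallbacks maps each label below to its reader in scene/dataset_readers.py.
def detect_scene_type(source_path):
    if os.path.exists(os.path.join(source_path, "sparse")):
        return "Colmap"
    if os.path.exists(os.path.join(source_path, "transforms_train.json")):
        return "Blender"
    if os.path.exists(os.path.join(source_path, "cameras_sphere.npz")):
        return "DTU"   # NeuS/IDR-style export with per-view world_mat/scale_mat
    raise ValueError("Could not recognize scene type!")
```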

scene/cameras.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", gt_depth=None
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", fid=None
                  ):
         super(Camera, self).__init__()
 
@@ -37,7 +37,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
         self.data_device = torch.device("cuda")
 
         self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
-        self.depth = gt_depth.to(self.data_device) if gt_depth is not None else None
+        self.fid = torch.LongTensor(np.array([fid])).to(self.data_device)
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
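Editor's note: the per-camera ground-truth depth is replaced by `fid`, a frame index carried over from the DTU camera file and stored on the camera as a one-element LongTensor. A tiny standalone illustration of that conversion (assumes only NumPy and PyTorch; the value 7 is a made-up example):

```python
import numpy as np
import torch

fid = 7                                          # frame index read from cameras_sphere.npz
fid_tensor = torch.LongTensor(np.array([fid]))   # tensor([7]), dtype int64, shape (1,)
fid_tensor = fid_tensor.to("cuda" if torch.cuda.is_available() else "cpu")
```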

scene/dataset_readers.py

Lines changed: 142 additions & 21 deletions
@@ -19,11 +19,14 @@
 import numpy as np
 import json
 import imageio
+from glob import glob
+import cv2 as cv
 from pathlib import Path
 from plyfile import PlyData, PlyElement
 from utils.sh_utils import SH2RGB
 from scene.gaussian_model import BasicPointCloud
 
+
 class CameraInfo(NamedTuple):
     uid: int
     R: np.array
@@ -35,7 +38,8 @@ class CameraInfo(NamedTuple):
     image_name: str
     width: int
     height: int
-    depth: np.array
+    fid: int
+
 
 class SceneInfo(NamedTuple):
     point_cloud: BasicPointCloud
@@ -44,6 +48,29 @@ class SceneInfo(NamedTuple):
     nerf_normalization: dict
     ply_path: str
 
+
+def load_K_Rt_from_P(filename, P=None):
+    if P is None:
+        lines = open(filename).read().splitlines()
+        if len(lines) == 4:
+            lines = lines[1:]
+        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+        P = np.asarray(lines).astype(np.float32).squeeze()
+
+    out = cv.decomposeProjectionMatrix(P)
+    K = out[0]
+    R = out[1]
+    t = out[2]
+
+    K = K / K[2, 2]
+
+    pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()
+    pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+    return K, pose
+
+
 def getNerfppNorm(cam_info):
     def get_center_and_diag(cam_centers):
         cam_centers = np.hstack(cam_centers)
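Editor's note on the helper added above: `cv.decomposeProjectionMatrix` factors a 3x4 projection matrix P into intrinsics K, a world-to-camera rotation R, and a homogeneous camera centre, so the returned `pose` (R transposed, dehomogenised centre as translation) is a camera-to-world matrix. A small self-contained sanity check of that factorisation (mine, not part of the commit; `P` is any valid 3x4 projection matrix):

```python
import numpy as np
import cv2 as cv

def recompose_residual(P):
    # Decompose exactly as load_K_Rt_from_P does.
    K, R, t = cv.decomposeProjectionMatrix(P)[:3]
    K = K / K[2, 2]
    C = (t[:3] / t[3])[:, 0]                 # camera centre in world coordinates

    # Rebuild the world-to-camera [R | -R C] and recompose P, defined only up to scale.
    Rt = np.concatenate([R, (-R @ C)[:, None]], axis=1)
    P_rec = K @ Rt
    P_rec = P_rec / np.linalg.norm(P_rec)
    P_n = P / np.linalg.norm(P)
    return min(np.abs(P_rec - P_n).max(), np.abs(P_rec + P_n).max())   # ~0 if consistent
```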
@@ -67,12 +94,13 @@ def get_center_and_diag(cam_centers):
 
     return {"translate": translate, "radius": radius}
 
+
 def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
     cam_infos = []
     for idx, key in enumerate(cam_extrinsics):
         sys.stdout.write('\r')
         # the exact output you're looking for:
-        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
+        sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics)))
         sys.stdout.flush()
 
         extr = cam_extrinsics[key]
@@ -84,11 +112,11 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
         R = np.transpose(qvec2rotmat(extr.qvec))
         T = np.array(extr.tvec)
 
-        if intr.model=="SIMPLE_PINHOLE":
+        if intr.model == "SIMPLE_PINHOLE":
             focal_length_x = intr.params[0]
             FovY = focal2fov(focal_length_x, height)
             FovX = focal2fov(focal_length_x, width)
-        elif intr.model=="PINHOLE":
+        elif intr.model == "PINHOLE":
             focal_length_x = intr.params[0]
             focal_length_y = intr.params[1]
             FovY = focal2fov(focal_length_y, height)
@@ -106,6 +134,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
     sys.stdout.write('\n')
     return cam_infos
 
+
 def fetchPly(path):
     plydata = PlyData.read(path)
     vertices = plydata['vertex']
@@ -114,12 +143,13 @@ def fetchPly(path):
     normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
     return BasicPointCloud(points=positions, colors=colors, normals=normals)
 
+
 def storePly(path, xyz, rgb):
     # Define the dtype for the structured array
     dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
-            ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
-            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
-
+             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
+             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
+
     normals = np.zeros_like(xyz)
 
     elements = np.empty(xyz.shape[0], dtype=dtype)
@@ -131,6 +161,7 @@ def storePly(path, xyz, rgb):
     ply_data = PlyData([vertex_element])
     ply_data.write(path)
 
+
 def readColmapSceneInfo(path, images, eval, llffhold=8):
     try:
         cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
@@ -144,8 +175,9 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
         cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
 
     reading_dir = "images" if images == None else images
-    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
-    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
+    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics,
+                                           images_folder=os.path.join(path, reading_dir))
+    cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
 
     if eval:
         train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
@@ -178,6 +210,7 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
                            ply_path=ply_path)
     return scene_info
 
+
 def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
     cam_infos = []
 
@@ -194,8 +227,8 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
             depth_name = os.path.join(path, frame["file_path"] + "_depth0000" + '.exr')
 
             matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
-            R = -np.transpose(matrix[:3,:3])
-            R[:,0] = -R[:,0]
+            R = -np.transpose(matrix[:3, :3])
+            R[:, 0] = -R[:, 0]
             T = -matrix[:3, 3]
 
             image_path = os.path.join(path, cam_name)
@@ -205,27 +238,29 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
 
             im_data = np.array(image.convert("RGBA"))
 
-            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
+            bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
 
             norm_data = im_data / 255.0
-            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
-            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
+            arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
+            image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB")
 
             fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
-            FovY = fovx 
+            FovY = fovx
             FovX = fovy
 
             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], depth=depth))
-
+                                        image_path=image_path, image_name=image_name, width=image.size[0],
+                                        height=image.size[1], depth=depth))
+
     return cam_infos
 
+
 def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
     print("Reading Training Transforms")
     train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
     print("Reading Test Transforms")
     test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
-
+
     if not eval:
         train_cam_infos.extend(test_cam_infos)
         test_cam_infos = []
@@ -237,7 +272,7 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
         # Since this data set has no colmap data, we start with random points
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")
-
+
        # We create random points inside the bounds of the synthetic Blender scenes
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
@@ -256,7 +291,93 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
                            ply_path=ply_path)
     return scene_info
 
+
+def readDTUCameras(path, render_camera, object_camera):
+    camera_dict = np.load(os.path.join(path, render_camera))
+    images_lis = sorted(glob(os.path.join(path, 'image/*.png')))
+    masks_lis = sorted(glob(os.path.join(path, 'mask/*.png')))
+    n_images = len(images_lis)
+    cam_infos = []
+    for idx, image_path in enumerate(images_lis):
+        image = np.array(Image.open(image_path))
+        mask = np.array(imageio.imread(masks_lis[idx])) / 255.0
+        image = Image.fromarray((image * mask).astype(np.uint8))
+        world_mat = camera_dict['world_mat_%d' % idx].astype(np.float32)
+        fid = camera_dict['fid_%d' % idx]
+        image_name = Path(image_path).stem
+        scale_mat = camera_dict['scale_mat_%d' % idx].astype(np.float32)
+        P = world_mat @ scale_mat
+        P = P[:3, :4]
+
+        K, pose = load_K_Rt_from_P(None, P)
+        a = pose[0:1, :]
+        b = pose[1:2, :]
+        c = pose[2:3, :]
+
+        pose = np.concatenate([a, -c, -b, pose[3:, :]], 0)
+
+        S = np.eye(3)
+        S[1, 1] = -1
+        S[2, 2] = -1
+        pose[1, 3] = -pose[1, 3]
+        pose[2, 3] = -pose[2, 3]
+        pose[:3, :3] = S @ pose[:3, :3] @ S
+
+        a = pose[0:1, :]
+        b = pose[1:2, :]
+        c = pose[2:3, :]
+
+        pose = np.concatenate([a, c, b, pose[3:, :]], 0)
+
+        pose[:, 3] *= 0.5
+
+        matrix = np.linalg.inv(pose)
+        R = -np.transpose(matrix[:3, :3])
+        R[:, 0] = -R[:, 0]
+        T = -matrix[:3, 3]
+
+        FovY = focal2fov(K[0, 0], image.size[1])
+        FovX = focal2fov(K[0, 0], image.size[0])
+        cam_info = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                              image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], fid=fid)
+        cam_infos.append(cam_info)
+    sys.stdout.write('\n')
+    return cam_infos
+
+
+def readNeuSDTUInfo(path, render_camera, object_camera):
+    print("Reading DTU Info")
+    train_cam_infos = readDTUCameras(path, render_camera, object_camera)
+
+    nerf_normalization = getNerfppNorm(train_cam_infos)
+
+    ply_path = os.path.join(path, "points3d.ply")
+    if not os.path.exists(ply_path):
+        # Since this data set has no colmap data, we start with random points
+        num_pts = 100_000
+        print(f"Generating random point cloud ({num_pts})...")
+
+        # We create random points inside the bounds of the synthetic Blender scenes
+        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
+        shs = np.random.random((num_pts, 3)) / 255.0
+        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
+
+        storePly(ply_path, xyz, SH2RGB(shs) * 255)
+    try:
+        pcd = fetchPly(ply_path)
+    except:
+        pcd = None
+
+    scene_info = SceneInfo(point_cloud=pcd,
+                           train_cameras=train_cam_infos,
+                           test_cameras=[],
+                           nerf_normalization=nerf_normalization,
+                           ply_path=ply_path)
+    return scene_info
+
+
 sceneLoadTypeCallbacks = {
     "Colmap": readColmapSceneInfo,
-    "Blender" : readNerfSyntheticInfo
-}
+    "Blender": readNerfSyntheticInfo,
+    "DTU": readNeuSDTUInfo,
+}
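Editor's note on `readDTUCameras`: the NeuS-style `cameras_sphere.npz` stores, per view, a `world_mat` (the full K·[R|t] projection) and a `scale_mat` that normalises the object into a unit sphere; their product is decomposed with `load_K_Rt_from_P`, and the resulting camera-to-world pose is then run through a sequence of axis swaps and sign flips before being inverted into the same R/T convention used by `readCamerasFromTransforms`. Those swaps amount to fixed axis permutations and sign flips of the world and camera frames plus a rescaling of the camera positions; isolating them as a pure function (my own transcription of the commit's lines, not new behaviour) makes the conversion easy to inspect or unit-test:

```python
import numpy as np

def convert_pose_like_readDTUCameras(pose):
    # Verbatim transcription of the pose manipulation inside readDTUCameras,
    # applied to a 4x4 camera-to-world matrix returned by load_K_Rt_from_P.
    pose = pose.copy()
    a, b, c = pose[0:1, :], pose[1:2, :], pose[2:3, :]
    pose = np.concatenate([a, -c, -b, pose[3:, :]], 0)   # swap y/z rows with sign flips

    S = np.eye(3)
    S[1, 1] = -1
    S[2, 2] = -1
    pose[1, 3] = -pose[1, 3]
    pose[2, 3] = -pose[2, 3]
    pose[:3, :3] = S @ pose[:3, :3] @ S                  # flip y/z of world and camera axes

    a, b, c = pose[0:1, :], pose[1:2, :], pose[2:3, :]
    pose = np.concatenate([a, c, b, pose[3:, :]], 0)     # swap y/z rows back without sign flips
    pose[:, 3] *= 0.5   # as in the commit: scales the translation column, homogeneous entry included
    return pose
```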

train.py

Lines changed: 3 additions & 3 deletions
@@ -71,7 +71,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             viewpoint_stack = scene.getTrainCameras().copy()
 
         viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
-        gt_depth = viewpoint_cam.depth
+        fid = viewpoint_cam.fid
         # Render
         render_pkg = render(viewpoint_cam, gaussians, pipe, background)
         image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
@@ -81,8 +81,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
         gt_image = viewpoint_cam.original_image.cuda()
         Ll1 = l1_loss(image, gt_image)
         loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
-        depth_loss = l1_loss(depth, gt_depth) * 0.1
-        loss = loss + depth_loss
+        # depth_loss = l1_loss(depth, gt_depth) * 0.1
+        # loss = loss + depth_loss
         loss.backward()
 
         iter_end.record()

utils/camera_utils.py

Lines changed: 1 addition & 12 deletions
@@ -39,19 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale):
         resolution = (int(orig_w / scale), int(orig_h / scale))
 
     resized_image_rgb = PILtoTorch(cam_info.image, resolution)
-    if cam_info.depth is not None:
-        resized_depth_rgb = ArrayToTorch(cam_info.depth, resolution)
-    else:
-        resized_depth_rgb = None
 
     gt_image = resized_image_rgb[:3, ...]
-    if resized_depth_rgb is not None:
-        depth_mask = resized_depth_rgb[0, ...] > 60000
-        gt_depth = resized_depth_rgb[0, ...]
-        gt_depth[depth_mask] = 0
-    else:
-        gt_depth = None
-
     loaded_mask = None
 
     if resized_image_rgb.shape[1] == 4:
@@ -60,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id, data_device=args.data_device, gt_depth=gt_depth)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device, fid=cam_info.fid)
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []

0 commit comments
