Skip to content

Commit

Permalink
Generate transforms.json at download
Browse files Browse the repository at this point in the history
  • Loading branch information
VasuAgrawal committed Dec 12, 2023
1 parent 0971dba commit a1996dd
Showing 1 changed file with 136 additions and 19 deletions.
155 changes: 136 additions & 19 deletions nerfstudio/scripts/downloads/download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

import copy
import json
import os
import shutil
import tarfile
Expand All @@ -27,6 +28,7 @@

import awscli.clidriver
import gdown
import numpy as np
import tyro
from typing_extensions import Annotated

Expand Down Expand Up @@ -477,14 +479,15 @@ class EyefulTowerResolutionMetadata:
folder_name: str
width: int
height: int
extension: str


# Maps each resolution name to its EyefulTowerResolutionMetadata
# (folder_name, width, height, extension); "all" maps to None.
# Every entry passes all four fields: `extension` has no default, so a
# three-argument call would raise TypeError, and repeating a key in a dict
# literal silently keeps only the last value.
eyefultower_resolutions = {
    "all": None,
    "jpeg_2k": EyefulTowerResolutionMetadata("images-jpeg-2k", 1368, 2048, "jpg"),
    "jpeg_4k": EyefulTowerResolutionMetadata("images-jpeg-4k", 2736, 4096, "jpg"),
    "jpeg_8k": EyefulTowerResolutionMetadata("images-jpeg", 5784, 8660, "jpg"),
    "exr_2k": EyefulTowerResolutionMetadata("images-2k", 1368, 2048, "exr"),
}

if TYPE_CHECKING:
Expand Down Expand Up @@ -559,6 +562,89 @@ def scale_metashape_transform(xml_tree: ET.ElementTree, target_width: int, targe

return transformed

def convert_cameras_to_nerfstudio_transforms(
    self, cameras: dict, target_width: int, target_height: int, extension: str
):
    """Build a transforms dict (camera model plus per-frame data) from ``cameras``.

    Args:
        cameras: Dict whose ``"KRT"`` entry is a list of per-camera dicts. Each
            one is read for ``cameraId``, ``distortionModel``, ``width``,
            ``height``, ``K``, ``distortion`` and ``T``.
        target_width: One of the two target image dimensions.
        target_height: The other target image dimension. For every camera the
            larger of the two targets is assigned to the width when the source
            width exceeds the source height, and to the height otherwise.
        extension: File extension (without the dot) appended to ``cameraId``
            to form each frame's ``file_path``.

    Returns:
        A dict with ``"camera_model"`` and ``"frames"``; frames are sorted by
        ``file_path``.

    Raises:
        AssertionError: If the cameras do not share exactly one distortion model.
        NotImplementedError: If the distortion model is neither
            ``"RadialAndTangential"`` nor ``"Fisheye"``.
    """
    camera_model_by_distortion = {
        "RadialAndTangential": "OPENCV",
        "Fisheye": "OPENCV_FISHEYE",
    }

    # Every camera must use the same distortion model.
    unique_models = list({entry["distortionModel"] for entry in cameras["KRT"]})
    assert len(unique_models) == 1
    model = unique_models[0]
    if model not in camera_model_by_distortion:
        raise NotImplementedError(f"Camera model {model} not implemented")

    result = {"camera_model": camera_model_by_distortion[model]}

    # The pair of target sizes never changes; only which axis gets which one.
    long_side = max(target_width, target_height)
    short_side = min(target_width, target_height)

    entries = []
    for cam in cameras["KRT"]:
        # TODO EXR
        entry = {"file_path": cam["cameraId"] + f".{extension}"}

        source_w = cam["width"]
        source_h = cam["height"]
        if source_w > source_h:
            out_w, out_h = long_side, short_side
        else:
            out_w, out_h = short_side, long_side
        scale_x = out_w / source_w
        scale_y = out_h / source_h

        entry["w"] = out_w
        entry["h"] = out_h

        intrinsics = np.array(cam["K"]).T  # Data stored as column-major
        entry["fl_x"] = intrinsics[0, 0] * scale_x
        entry["fl_y"] = intrinsics[1, 1] * scale_y
        entry["cx"] = intrinsics[0, 2] * scale_x
        entry["cy"] = intrinsics[1, 2] * scale_y

        if model == "RadialAndTangential":
            # pinhole: [k1, k2, p1, p2, k3]
            coeffs = cam["distortion"]
            entry.update(
                k1=coeffs[0],
                k2=coeffs[1],
                k3=coeffs[4],
                k4=0.0,
                p1=coeffs[2],
                p2=coeffs[3],
            )
        elif model == "Fisheye":
            # fisheye: [k1, k2, k3, _, _, _, p1, p2]
            coeffs = cam["distortion"]
            entry.update(
                k1=coeffs[0],
                k2=coeffs[1],
                k3=coeffs[2],
                p1=coeffs[6],
                p2=coeffs[7],
            )
        else:
            raise NotImplementedError("This shouldn't happen")

        # Invert the stored (column-major) 4x4 matrix, reorder its rows to
        # (2, 0, 1, 3) and negate columns 1 and 2.
        # NOTE(review): presumably an axis-convention change for nerfstudio — confirm.
        pose = np.linalg.inv(np.array(cam["T"]).T)
        pose = pose[[2, 0, 1, 3], :]
        pose[:, 1:3] *= -1
        entry["transform_matrix"] = pose.tolist()

        entries.append(entry)

    entries.sort(key=lambda item: item["file_path"])
    result["frames"] = entries
    return result

def subsample_nerfstudio_transforms(self, transforms: dict, n: int):
    """Return a copy of ``transforms`` keeping at most ``n`` evenly spaced frames.

    Frame indices are chosen by rounding ``np.linspace`` over the full index
    range, so when at least two frames are kept the first and last are always
    among them.

    Args:
        transforms: Dict with a ``"frames"`` list; every other key is carried
            over unchanged.
        n: Maximum number of frames to keep. If it is at least the number of
            available frames, all frames are kept.

    Returns:
        A new dict that shares no mutable state with ``transforms``: the
        non-frame entries and the selected frames are deep-copied. The input
        is left untouched.
    """
    all_frames = transforms["frames"]
    target = min(len(all_frames), n)
    indices = np.round(np.linspace(0, len(all_frames) - 1, target)).astype(int)

    # Deep-copy only what is returned. Copying the whole input would duplicate
    # every frame just to throw most of them away, and the frames actually
    # returned would still be the caller's own objects.
    selected = [copy.deepcopy(all_frames[int(i)]) for i in indices]

    # Rebuild key by key so "frames" keeps its original position in the dict
    # (and therefore in any JSON written from it).
    return {
        key: selected if key == "frames" else copy.deepcopy(value)
        for key, value in transforms.items()
    }

def download(self, save_dir: Path):
if len(self.capture_name) == 0:
self.capture_name = ("riverview",)
Expand Down Expand Up @@ -623,21 +709,52 @@ def download(self, save_dir: Path):
xml_input_path = output_path / "cameras.xml"
if not xml_input_path.exists:
print(" WARNING: cameras.xml not found. Scaled cameras.xml will not be generated.")
continue

tree = ET.parse(output_path / "cameras.xml")

for resolution in resolutions:
metadata = eyefultower_resolutions[resolution]
xml_output_path = output_path / metadata.folder_name / "cameras.xml"
print(
f" Generating cameras.xml for '{resolution}' to {xml_output_path.resolve()} ... ",
end=" ",
flush=True,
)
scaled_tree = self.scale_metashape_transform(tree, metadata.width, metadata.height)
scaled_tree.write(xml_output_path)
print("done!")
else:
tree = ET.parse(output_path / "cameras.xml")

for resolution in resolutions:
metadata = eyefultower_resolutions[resolution]
xml_output_path = output_path / metadata.folder_name / "cameras.xml"
print(
f" Generating cameras.xml for '{resolution}' to {xml_output_path.resolve()} ... ",
end=" ",
flush=True,
)
scaled_tree = self.scale_metashape_transform(tree, metadata.width, metadata.height)
scaled_tree.write(xml_output_path)
print("done!")

json_input_path = output_path / "cameras.json"
if not json_input_path.exists:
print(" WARNING: cameras.json not found. transforms.json will not be generated.")
else:
with open(json_input_path, "r") as f:
cameras = json.load(f)

for resolution in resolutions:
metadata = eyefultower_resolutions[resolution]
json_output_path = output_path / metadata.folder_name / "transforms.json"
print(
f" Generating transforms.json for '{resolution}' to {json_output_path.resolve()} ... ",
end=" ",
flush=True,
)
transforms = self.convert_cameras_to_nerfstudio_transforms(
cameras, metadata.width, metadata.height, metadata.extension
)

with open(json_output_path, "w", encoding="utf8") as f:
json.dump(transforms, f, indent=4)

for count, name in [
(300, "transforms_300.json"),
(int(len(cameras["KRT"]) // 2), "transforms_half.json"),
]:
subsampled = self.subsample_nerfstudio_transforms(transforms, count)
with open(json_output_path.with_name(name), "w", encoding="utf8") as f:
json.dump(subsampled, f, indent=4)

print("done!")


Commands = Union[
Expand Down

0 comments on commit a1996dd

Please sign in to comment.