Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support downscale_factor for colmap dataset #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 190 additions & 43 deletions dn_splatter/data/coolermap_dataparser.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

import math
import glob
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Literal, Optional, Type
from typing import Literal, Optional, Type, List

import cv2
import sys
import numpy as np
from PIL import Image
import open3d as o3d
import torch
from dn_splatter.scripts.align_depth import ColmapToAlignedMonoDepths
Expand All @@ -16,6 +21,7 @@
)
from natsort import natsorted
from rich.console import Console
from rich.prompt import Confirm

from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import CAMERA_MODEL_TO_TYPE, Cameras
Expand All @@ -26,8 +32,11 @@
)
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.plugins.registry_dataparser import DataParserSpecification
from nerfstudio.data.utils import colmap_parsing_utils as colmap_utils
from nerfstudio.process_data.colmap_utils import colmap_to_json
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.rich_utils import CONSOLE, status
from nerfstudio.utils.scripts import run_command


MAX_AUTO_RESOLUTION = 1600
CONSOLE = Console()
Expand Down Expand Up @@ -77,7 +86,7 @@ class CoolerMapDataParserConfig(ColmapDataParserConfig):
"""The method to use to center the poses."""
auto_scale_poses: bool = False
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
downscale_factor: int = 1
downscale_factor: Optional[int] = None


class CoolerMapDataParser(ColmapDataParser):
Expand All @@ -93,10 +102,13 @@ def get_depth_filepaths(self):
depth_paths = natsorted(
glob.glob(f"{self.config.data}/mono_depth/*_aligned.npy")
)
depth_paths = [Path(depth_path) for depth_path in depth_paths]
return depth_paths

def get_normal_filepaths(self):
return natsorted(glob.glob(f"{self.normal_save_dir}/*.png"))
normal_paths = natsorted(glob.glob(f"{self.normal_save_dir}/*.png"))
normal_paths = [Path(normal_path) for normal_path in normal_paths]
return normal_paths

def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
assert (
Expand Down Expand Up @@ -149,7 +161,6 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
You should check that mask_path is specified for every frame (or zero frames) in transforms.json.
"""

depth_filenames = self.get_depth_filepaths()
poses = [
pose for img, pose in natsorted(zip(image_filenames, poses), lambda x: x[0])
]
Expand All @@ -175,7 +186,9 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
indices = indices[:: self.config.load_every]

metadata = {}

# load depths
depth_filenames = []
if self.config.depth_mode != "none" and self.config.load_depths:
if not (self.config.data / "mono_depth").exists():
CONSOLE.print(
Expand All @@ -185,21 +198,55 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
data=self.config.data, mono_depth_network=self.config.mono_pretrain
).main()
depth_filenames = self.get_depth_filepaths()

# load normals
normal_filenames = []
if self.config.normals_from == "depth":
self.normal_save_dir = self.config.data / Path("normals_from_depth")
else:
self.normal_save_dir = self.config.data / Path("normals_from_pretrain")

if self.config.load_normals:
if not (self.normal_save_dir).exists() or len(os.listdir(self.normal_save_dir)) == 0:
CONSOLE.print(
f"[bold yellow]Could not find normals, generating them into {str(self.normal_save_dir)}"
)
self.normal_save_dir.mkdir(exist_ok=True, parents=True)
if self.config.normals_from == "depth":
normals_from_depths(
path_to_transforms=Path(image_filenames[0]).parent.parent
/ "transforms.json",
normal_format=self.config.normal_format,
)
elif self.config.normals_from == "pretrained":
NormalsFromPretrained(data_dir=self.config.data).main()
else:
raise NotImplementedError
normal_filenames = self.get_normal_filepaths()


image_filenames, mask_filenames, depth_filenames, normal_filenames, downscale_factor = self._setup_downscale_factor(
image_filenames, mask_filenames, depth_filenames, normal_filenames
)

if self.config.load_depths:
metadata["mono_depth_filenames"] = [
Path(depth_filenames[i]) for i in indices
]

if self.config.load_normals:
metadata["normal_filenames"] = [
Path(normal_filenames[i]) for i in indices
]

image_filenames = [image_filenames[i] for i in indices]
mask_filenames = (
[mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else []
)
idx_tensor = torch.tensor(indices, dtype=torch.long)
poses = poses[idx_tensor]

if self.config.load_depths:
assert len(metadata["mono_depth_filenames"]) == len(image_filenames)

# in x,y,z order
# in x,y,z order
# assumes that the scene is centered at the origin
aabb_scale = self.config.scene_scale
scene_box = SceneBox(
Expand Down Expand Up @@ -232,7 +279,9 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
camera_type=camera_type,
)

# cameras.rescale_output_resolution(scaling_factor=1.0 / downscale_factor)
cameras.rescale_output_resolution(
scaling_factor=1.0 / downscale_factor, scale_rounding_mode=self.config.downscale_rounding_mode
)

if "applied_transform" in meta:
applied_transform = torch.tensor(
Expand All @@ -258,40 +307,9 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
metadata.update({"depth_mode": self.config.depth_mode})
metadata.update({"load_depths": self.config.load_depths})
metadata.update({"is_euclidean_depth": self.config.is_euclidean_depth})

# load normals
if self.config.normals_from == "depth":
self.normal_save_dir = self.config.data / Path("normals_from_depth")
else:
self.normal_save_dir = self.config.data / Path("normals_from_pretrain")

if self.config.load_normals and (
not (self.normal_save_dir).exists()
or len(os.listdir(self.normal_save_dir)) == 0
):
CONSOLE.print(
f"[bold yellow]Could not find normals, generating them into {str(self.normal_save_dir)}"
)
self.normal_save_dir.mkdir(exist_ok=True, parents=True)
if self.config.normals_from == "depth":
normals_from_depths(
path_to_transforms=Path(image_filenames[0]).parent.parent
/ "transforms.json",
normal_format=self.config.normal_format,
)
elif self.config.normals_from == "pretrained":
NormalsFromPretrained(data_dir=self.config.data).main()
else:
raise NotImplementedError

if self.config.load_normals:
normal_filenames = self.get_normal_filepaths()
metadata.update(
{"normal_filenames": [Path(normal_filenames[idx]) for idx in indices]}
)
metadata.update({"normal_format": self.config.normal_format})

metadata.update({"load_normals": self.config.load_normals})
metadata.update({"normal_format": self.config.normal_format})

if self.config.load_pcd_normals:
metadata.update(
self._load_points3D_normals(points_3d=metadata["points3D_xyz"])
Expand All @@ -317,6 +335,135 @@ def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
)

return dataparser_outputs

def _downscale_numpy(
self,
paths,
get_fname,
downscale_factor: int,
downscale_rounding_mode: str = "floor",
nearest_neighbor: bool = False,
):
def calculate_scaled_size(original_width, original_height, downscale_factor, mode="floor"):
if mode == "floor":
return math.floor(original_width / downscale_factor), math.floor(original_height / downscale_factor)
elif mode == "round":
return round(original_width / downscale_factor), round(original_height / downscale_factor)
elif mode == "ceil":
return math.ceil(original_width / downscale_factor), math.ceil(original_height / downscale_factor)
else:
raise ValueError("Invalid mode. Choose from 'floor', 'round', or 'ceil'.")

with status(msg="[bold yellow]Downscaling images...", spinner="growVertical"):
assert downscale_factor > 1
assert isinstance(downscale_factor, int)
filepath = next(iter(paths))
img = np.load(filepath)
w, h = img.shape[1], img.shape[0]
w_scaled, h_scaled = calculate_scaled_size(w, h, downscale_factor, downscale_rounding_mode)
# Using %05d ffmpeg commands appears to be unreliable (skips images).
for path in paths:
img = np.load(path)
img = cv2.resize(
img, (w_scaled, h_scaled), interpolation=cv2.INTER_NEAREST
)
path_out = get_fname(path)
path_out.parent.mkdir(parents=True, exist_ok=True)
np.save(path_out, img)

CONSOLE.log("[bold green]:tada: Done downscaling images.")

def _setup_downscale_factor(
self, image_filenames: List[Path], mask_filenames: List[Path], depth_filenames: List[Path], normal_filenames: List[Path]
):
"""
Setup the downscale factor for the dataset. This is used to downscale the images and cameras.
"""

def get_fname(parent: Path, filepath: Path) -> Path:
"""Returns transformed file name when downscale factor is applied"""
rel_part = filepath.relative_to(parent)
base_part = parent.parent / (str(parent.name) + f"_{self._downscale_factor}")
return base_part / rel_part

filepath = next(iter(image_filenames))
if self._downscale_factor is None:
if self.config.downscale_factor is None:
test_img = Image.open(filepath)
w, h = test_img.size
max_res = max(h, w)
df = 0
while True:
if (max_res / 2 ** (df)) <= MAX_AUTO_RESOLUTION:
break
df += 1

self._downscale_factor = 2**df
CONSOLE.log(f"Using image downscale factor of {self._downscale_factor}")
else:
self._downscale_factor = self.config.downscale_factor
if self._downscale_factor > 1 and not all(
get_fname(self.config.data / self.config.images_path, fp).parent.exists() for fp in image_filenames
):
# Downscaled images not found
# Ask if user wants to downscale the images automatically here
CONSOLE.print(
f"[bold red]Downscaled images do not exist for factor of {self._downscale_factor}.[/bold red]"
)
if Confirm.ask(
f"\nWould you like to downscale the images using '{self.config.downscale_rounding_mode}' rounding mode now?",
default=False,
console=CONSOLE,
):
# Install the method
self._downscale_images(
image_filenames,
partial(get_fname, self.config.data / self.config.images_path),
self._downscale_factor,
self.config.downscale_rounding_mode,
nearest_neighbor=False,
)
if len(mask_filenames) > 0:
assert self.config.masks_path is not None
self._downscale_images(
mask_filenames,
partial(get_fname, self.config.data / self.config.masks_path),
self._downscale_factor,
self.config.downscale_rounding_mode,
nearest_neighbor=True,
)
if len(depth_filenames) > 0:
self._downscale_numpy(
depth_filenames,
partial(get_fname, self.config.data / "mono_depth"),
self._downscale_factor,
self.config.downscale_rounding_mode,
nearest_neighbor=True,
)
if len(normal_filenames) > 0:
self._downscale_images(
normal_filenames,
partial(get_fname, self.normal_save_dir),
self._downscale_factor,
self.config.downscale_rounding_mode,
nearest_neighbor=True,
)
else:
sys.exit(1)

# Return transformed filenames
if self._downscale_factor > 1:
image_filenames = [get_fname(self.config.data / self.config.images_path, fp) for fp in image_filenames]
if len(mask_filenames) > 0:
assert self.config.masks_path is not None
mask_filenames = [get_fname(self.config.data / self.config.masks_path, fp) for fp in mask_filenames]
if len(depth_filenames) > 0:
depth_filenames = [get_fname(self.config.data / "mono_depth", fp) for fp in depth_filenames]
if len(normal_filenames) > 0:
normal_filenames = [get_fname(self.normal_save_dir, fp) for fp in normal_filenames]
assert isinstance(self._downscale_factor, int)
return image_filenames, mask_filenames, depth_filenames, normal_filenames, self._downscale_factor


def _load_points3D_normals(self, points_3d):
transform_matrix = torch.eye(4, dtype=torch.float, device="cpu")[:3, :4]
Expand Down