From 4d87903a2726a943883e1553fd0edd3dace6b560 Mon Sep 17 00:00:00 2001 From: Rebecca Feng Date: Mon, 8 Jul 2024 15:34:31 -0700 Subject: [PATCH 1/2] pr changes --- examples/24_gaussian_splats.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/24_gaussian_splats.py b/examples/24_gaussian_splats.py index aab68dab..06bd5f70 100644 --- a/examples/24_gaussian_splats.py +++ b/examples/24_gaussian_splats.py @@ -2,7 +2,6 @@ from __future__ import annotations -import sys import time from pathlib import Path from typing import TypedDict @@ -44,7 +43,7 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: ) assert len(splat_buffer) % bytes_per_gaussian == 0 num_gaussians = len(splat_buffer) // bytes_per_gaussian - print(f"{num_gaussians=}") + print("Number of gaussians to render: ", f"{num_gaussians=}") # Reinterpret cast to dtypes that we want to extract. splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape( @@ -52,7 +51,6 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: ) scales = splat_uint8[:, 12:24].copy().view(onp.float32) wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 - print(onp.shape(wxyzs)) Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs]) covariances = onp.einsum( "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs @@ -60,7 +58,7 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: centers = splat_uint8[:, 0:12].copy().view(onp.float32) if center: centers -= onp.mean(centers, axis=0, keepdims=True) - print("loaded") + print("Render loaded") return { "centers": centers, # Colors should have shape (N, 3). @@ -78,11 +76,13 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: -onp.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) / (1 + onp.exp(-vert["opacity"])) ) - colors = onp.zeros((len(sorted_indices), 3)) - opacities = onp.zeros((len(sorted_indices), 1)) - positions = onp.zeros((len(sorted_indices), 3)) - wxyzs = onp.zeros((len(sorted_indices), 4)) - scales = onp.zeros((len(sorted_indices), 3)) + numgaussians = len(vert) + print("Number of gaussians to render: ", numgaussians) + colors = onp.zeros((numgaussians, 3)) + opacities = onp.zeros((numgaussians, 1)) + positions = onp.zeros((numgaussians, 3)) + wxyzs = onp.zeros((numgaussians, 4)) + scales = onp.zeros((numgaussians, 3)) for idx in sorted_indices: v = plydata["vertex"][idx] position = onp.array([v["x"], v["y"], v["z"]], dtype=onp.float32) @@ -102,9 +102,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: ] ) opacity = 1 / (1 + onp.exp(-v["opacity"])) - wxyz = ((rot / onp.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype( - onp.uint8 - ) / 255.0 * 2.0 - 1.0 + wxyz = rot / onp.linalg.norm(rot) # normalize scales[idx] = scale colors[idx] = color opacities[idx] = onp.array([opacity]) @@ -117,7 +115,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: ) if center: positions -= onp.mean(positions, axis=0, keepdims=True) - print("loaded") + print("Render loaded") return { "centers": positions, # Colors should have shape (N, 3). @@ -150,7 +148,7 @@ def _(event: viser.GuiEvent) -> None: elif splat_path.suffix == ".ply": splat_data = load_ply_file(splat_path, center=True) else: - sys.exit("Please provide a filepath to a .splat or .ply file.") + raise SystemExit("Please provide a filepath to a .splat or .ply file.") server.scene.add_transform_controls(f"/{i}") server.scene.add_gaussian_splats( From 970931656e780a68986e82d9ad4b1958280af7bc Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 9 Jul 2024 15:34:10 +0900 Subject: [PATCH 2/2] nit --- examples/24_gaussian_splats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/24_gaussian_splats.py b/examples/24_gaussian_splats.py index 06bd5f70..6889aaca 100644 --- a/examples/24_gaussian_splats.py +++ b/examples/24_gaussian_splats.py @@ -58,7 +58,7 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: centers = splat_uint8[:, 0:12].copy().view(onp.float32) if center: centers -= onp.mean(centers, axis=0, keepdims=True) - print("Render loaded") + print("Splat file loaded") return { "centers": centers, # Colors should have shape (N, 3). @@ -115,7 +115,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: ) if center: positions -= onp.mean(positions, axis=0, keepdims=True) - print("Render loaded") + print("PLY file loaded") return { "centers": positions, # Colors should have shape (N, 3).