From 8bed243a33b0d89df6d51d5e8594b4022176050c Mon Sep 17 00:00:00 2001
From: Gina Wu
Date: Thu, 10 Oct 2024 03:01:19 -0700
Subject: [PATCH] init keyframe value capture for image embeddings

---
Reviewer notes (this section is not part of the commit):
- The line structure of the patch was restored; the indentation of context
  lines follows the target file as best it could be recovered and should be
  re-checked against the tree before applying.
- The capture callback was revised: eval()/restore instead of assigning
  `training = False`, `.cpu()` before `.numpy()`, an explicit float32 pose,
  clipped colours, and the unused direction locals removed. Hunk counts and
  the diffstat were updated to match; the post-image blob id on the `index`
  line still comes from the original commit.
- RayBundle, colormaps and quaternion_matrix are imported but not used by the
  code in this patch; the imports were left in place.

 nerfstudio/viewer/render_panel.py | 65 +++++++++++++++++++++++++++++--
 1 file changed, 62 insertions(+), 3 deletions(-)

diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py
index cf05c59c6c..288a9b2e2b 100644
--- a/nerfstudio/viewer/render_panel.py
+++ b/nerfstudio/viewer/render_panel.py
@@ -20,6 +20,8 @@
 import json
 import threading
 import time
+import torch
+
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
@@ -30,7 +32,10 @@
 import viser.transforms as tf
 from scipy import interpolate
 
+from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle
+from nerfstudio.cameras.camera_utils import quaternion_matrix
 from nerfstudio.models.base_model import Model
+from nerfstudio.utils import colormaps
 from nerfstudio.viewer.control_panel import ControlPanel
 
 if TYPE_CHECKING:
@@ -66,9 +71,14 @@ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe:
 
 class CameraPath:
     def __init__(
-        self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float], time_enabled: bool = False
+        self,
+        server: viser.ViserServer,
+        duration_element: viser.GuiInputHandle[float],
+        viewer_model: Model,
+        time_enabled: bool = False
     ):
         self._server = server
+        self._model = viewer_model
         self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}
         self._keyframe_counter: int = 0
         self._spline_nodes: List[viser.SceneNodeHandle] = []
@@ -167,6 +177,7 @@ def _(_) -> None:
 
                 delete_button = server.gui.add_button("Delete", color="red", icon=viser.Icon.TRASH)
                 go_to_button = server.gui.add_button("Go to")
+                capture_button = server.gui.add_button("Capture")
                 close_button = server.gui.add_button("Close")
 
             @override_fov.on_update
@@ -232,6 +243,54 @@ def _(event: viser.GuiEvent) -> None:
                         client.camera.wxyz = T_world_set.rotation().wxyz
                         client.camera.position = T_world_set.translation()
                     time.sleep(1.0 / 30.0)
+
+            @capture_button.on_click
+            def _(event: viser.GuiEvent) -> None:
+                """Render this keyframe's view with the viewer model and save it as capture.png."""
+                # NOTE(review): imported at call time; presumably to avoid an import cycle -- confirm.
+                from PIL import Image
+
+                from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO
+
+                # NOTE(review): resolution is hard-coded, and fx (width / 2) and fy (height)
+                # are different focal lengths -- confirm these intrinsics are intended.
+                image_height = 607.0
+                image_width = 1060.0
+
+                # Keyframe pose -> camera-to-world: rotate by pi about x, divide out the viewer scale.
+                c2w = tf.SE3.from_rotation_and_translation(
+                    tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi),
+                    keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO,
+                ).as_matrix()
+                # Keep the top three rows; cast explicitly because torch.tensor would otherwise
+                # inherit the numpy array's dtype.
+                c2w = torch.tensor(c2w[:3], dtype=torch.float32)
+
+                camera = Cameras(
+                    fx=image_width / 2,
+                    fy=image_height,
+                    cx=image_width / 2,
+                    cy=image_height / 2,
+                    camera_to_worlds=c2w,
+                    camera_type=CameraType.PERSPECTIVE,
+                    times=None,
+                )
+
+                # eval() switches every submodule, unlike assigning `training = False`; the
+                # previous mode is restored afterwards so a capture does not leave the model
+                # out of training mode.
+                was_training = self._model.training
+                self._model.eval()
+                try:
+                    outputs = self._model.get_outputs_for_camera(camera)
+                finally:
+                    self._model.train(was_training)
+
+                # Tensor.numpy() only works on CPU tensors, so copy to host memory first.
+                rgb = outputs["rgb"].detach().cpu().numpy()
+                # Clip before the uint8 cast so out-of-range values cannot wrap around.
+                im = Image.fromarray((np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8))
+                im.save("capture.png")
 
             @close_button.on_click
             def _(_) -> None:
@@ -1169,9 +1228,9 @@ def _(_) -> None:
                 modal.close()
 
     if control_panel is not None:
-        camera_path = CameraPath(server, duration_number, control_panel._time_enabled)
+        camera_path = CameraPath(server, duration_number, viewer_model, control_panel._time_enabled)
     else:
-        camera_path = CameraPath(server, duration_number)
+        camera_path = CameraPath(server, duration_number, viewer_model)
     camera_path.tension = tension_slider.value
     camera_path.default_fov = fov_degrees.value / 180.0 * np.pi
     camera_path.default_transition_sec = transition_sec_number.value