diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 10d263f8c2..288a9b2e2b 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -20,8 +20,10 @@ import json import threading import time +import torch + from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import numpy as np import splines @@ -30,8 +32,15 @@ import viser.transforms as tf from scipy import interpolate +from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle +from nerfstudio.cameras.camera_utils import quaternion_matrix +from nerfstudio.models.base_model import Model +from nerfstudio.utils import colormaps from nerfstudio.viewer.control_panel import ControlPanel +if TYPE_CHECKING: + from viser import GuiInputHandle + @dataclasses.dataclass class Keyframe: @@ -62,9 +71,14 @@ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe: class CameraPath: def __init__( - self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float], time_enabled: bool = False + self, + server: viser.ViserServer, + duration_element: viser.GuiInputHandle[float], + viewer_model: Model, + time_enabled: bool = False ): self._server = server + self._model = viewer_model self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {} self._keyframe_counter: int = 0 self._spline_nodes: List[viser.SceneNodeHandle] = [] @@ -163,6 +177,7 @@ def _(_) -> None: delete_button = server.gui.add_button("Delete", color="red", icon=viser.Icon.TRASH) go_to_button = server.gui.add_button("Go to") + capture_button = server.gui.add_button("Capture") close_button = server.gui.add_button("Close") @override_fov.on_update @@ -228,6 +243,42 @@ def _(event: viser.GuiEvent) -> None: client.camera.wxyz = T_world_set.rotation().wxyz client.camera.position = T_world_set.translation() time.sleep(1.0 / 30.0) + + @capture_button.on_click + def _(event: viser.GuiEvent) -> None: + R = quaternion_matrix(keyframe.wxyz) + forward_vector = np.array([0, 0, 1, 1]) + direction = R @ forward_vector + direction = [direction[:3]] + + image_height = 607.0 + image_width = 1060.0 + + from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO + + c2w = tf.SE3.from_rotation_and_translation( + tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi), + keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO, + ).as_matrix() + c2w = torch.tensor(c2w[:3]) + + camera = Cameras( + fx=image_width / 2, + fy=image_height, + cx=image_width / 2, + cy=image_height / 2, + camera_to_worlds=c2w, + camera_type=CameraType.PERSPECTIVE, + times=None, + ) + + self._model.training = False + outputs = self._model.get_outputs_for_camera(camera) + + from PIL import Image + _im = outputs['rgb'].detach().numpy() + im = Image.fromarray((_im * 255).astype(np.uint8)) + im.save("capture.png") @close_button.on_click def _(_) -> None: @@ -520,8 +571,9 @@ def populate_render_tab( server: viser.ViserServer, config_path: Path, datapath: Path, + viewer_model: Model, control_panel: Optional[ControlPanel] = None, -) -> RenderTabState: +) -> Tuple[RenderTabState, CameraPath, GuiInputHandle, GuiInputHandle]: from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO render_tab_state = RenderTabState( @@ -588,6 +640,7 @@ def _(_) -> None: initial_value="Perspective", hint="Camera model to render with. This is applied to all keyframes.", ) + add_button = server.gui.add_button( "Add Keyframe", icon=viser.Icon.PLUS, @@ -1163,14 +1216,14 @@ def _(_) -> None: modal.close() if control_panel is not None: - camera_path = CameraPath(server, duration_number, control_panel._time_enabled) + camera_path = CameraPath(server, duration_number, viewer_model, control_panel._time_enabled) else: - camera_path = CameraPath(server, duration_number) + camera_path = CameraPath(server, duration_number, viewer_model) camera_path.tension = tension_slider.value camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value - return render_tab_state + return render_tab_state, camera_path, duration_number, resolution if __name__ == "__main__": @@ -1178,6 +1231,7 @@ def _(_) -> None: server=viser.ViserServer(), config_path=Path("."), datapath=Path("."), + viewer_model=Model, ) while True: time.sleep(10.0) diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index bc58043aa6..1ca514d73b 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -94,7 +94,9 @@ def __init__( self.include_time = self.pipeline.datamanager.includes_time if self.config.websocket_port is None: - websocket_port = viewer_utils.get_free_port(default_port=self.config.websocket_port_default) + websocket_port = viewer_utils.get_free_port( + default_port=self.config.websocket_port_default + ) else: websocket_port = self.config.websocket_port self.log_filename.parent.mkdir(exist_ok=True) @@ -106,12 +108,16 @@ def __init__( self.train_btn_state: Literal["training", "paused", "completed"] = ( "training" if self.trainer is None else self.trainer.training_state ) - self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state + self._prev_train_state: Literal["training", "paused", "completed"] = ( + self.train_btn_state + ) self.last_move_time = 0 # track the camera index that last being clicked self.current_camera_idx = 0 - self.viser_server = viser.ViserServer(host=config.websocket_host, port=websocket_port) + self.viser_server = viser.ViserServer( + host=config.websocket_host, port=websocket_port + ) # Set the name of the URL either to the share link if available, or the localhost share_url = None if share: @@ -120,15 +126,21 @@ def __init__( print("Couldn't make share URL!") if share_url is not None: - self.viewer_info = [f"Viewer at: http://localhost:{websocket_port} or {share_url}"] + self.viewer_info = [ + f"Viewer at: http://localhost:{websocket_port} or {share_url}" + ] elif config.websocket_host == "0.0.0.0": # 0.0.0.0 is not a real IP address and was confusing people, so # we'll just print localhost instead. There are some security # (and IPv6 compatibility) implications here though, so we should # note that the server is bound to 0.0.0.0! - self.viewer_info = [f"Viewer running locally at: http://localhost:{websocket_port} (listening on 0.0.0.0)"] + self.viewer_info = [ + f"Viewer running locally at: http://localhost:{websocket_port} (listening on 0.0.0.0)" + ] else: - self.viewer_info = [f"Viewer running locally at: http://{config.websocket_host}:{websocket_port}"] + self.viewer_info = [ + f"Viewer running locally at: http://{config.websocket_host}:{websocket_port}" + ] buttons = ( viser.theme.TitlebarButton( @@ -195,8 +207,8 @@ def __init__( self.show_images.visible = False mkdown = self.make_stats_markdown(0, "0x0px") self.stats_markdown = self.viser_server.gui.add_markdown(mkdown) - tabs = self.viser_server.gui.add_tab_group() - control_tab = tabs.add_tab("Control", viser.Icon.SETTINGS) + self.tabs = self.viser_server.gui.add_tab_group() + control_tab = self.tabs.add_tab("Control", viser.Icon.SETTINGS) with control_tab: self.control_panel = ControlPanel( self.viser_server, @@ -208,13 +220,25 @@ def __init__( default_composite_depth=self.config.default_composite_depth, ) config_path = self.log_filename.parents[0] / "config.yml" - with tabs.add_tab("Render", viser.Icon.CAMERA): - self.render_tab_state = populate_render_tab( - self.viser_server, config_path, self.datapath, self.control_panel + + with self.tabs.add_tab("Render", viser.Icon.CAMERA): + ( + self.render_tab_state, + self.camera_path, + self.duration_number, + self.resolution, + ) = populate_render_tab( + self.viser_server, + config_path, + self.datapath, + self.pipeline.model, + self.control_panel, ) - with tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): - populate_export_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + with self.tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): + populate_export_tab( + self.viser_server, self.control_panel, config_path, self.pipeline.model + ) # Keep track of the pointers to generated GUI folders, because each generated folder holds a unique ID. viewer_gui_folders = dict() @@ -224,17 +248,26 @@ def prev_cb_wrapper(prev_cb): # concurrently executing render thread. This may block rendering, however this can be necessary # if the callback uses get_outputs internally. def cb_lock(element): - with self.train_lock if self.train_lock is not None else contextlib.nullcontext(): + with ( + self.train_lock + if self.train_lock is not None + else contextlib.nullcontext() + ): prev_cb(element) return cb_lock - def nested_folder_install(folder_labels: List[str], prev_labels: List[str], element: ViewerElement): + def nested_folder_install( + folder_labels: List[str], prev_labels: List[str], element: ViewerElement + ): if len(folder_labels) == 0: element.install(self.viser_server) # also rewire the hook to rerender prev_cb = element.cb_hook - element.cb_hook = lambda element: [prev_cb_wrapper(prev_cb)(element), self._trigger_rerender()] + element.cb_hook = lambda element: [ + prev_cb_wrapper(prev_cb)(element), + self._trigger_rerender(), + ] else: # recursively create folders # If the folder name is "Custom Elements/a/b", then: @@ -250,12 +283,18 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem # Otherwise, use the existing folder as context manager. folder_path = "/".join(prev_labels + [folder_labels[0]]) if folder_path not in viewer_gui_folders: - viewer_gui_folders[folder_path] = self.viser_server.gui.add_folder(folder_labels[0]) + viewer_gui_folders[folder_path] = self.viser_server.gui.add_folder( + folder_labels[0] + ) with viewer_gui_folders[folder_path]: - nested_folder_install(folder_labels[1:], prev_labels + [folder_labels[0]], element) + nested_folder_install( + folder_labels[1:], prev_labels + [folder_labels[0]], element + ) with control_tab: - from nerfstudio.viewer_legacy.server.viewer_elements import ViewerElement as LegacyViewerElement + from nerfstudio.viewer_legacy.server.viewer_elements import ( + ViewerElement as LegacyViewerElement, + ) if len(parse_object(pipeline, LegacyViewerElement, "Custom Elements")) > 0: from nerfstudio.utils.rich_utils import CONSOLE @@ -265,7 +304,9 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem style="bold yellow", ) self.viewer_elements = [] - self.viewer_elements.extend(parse_object(pipeline, ViewerElement, "Custom Elements")) + self.viewer_elements.extend( + parse_object(pipeline, ViewerElement, "Custom Elements") + ) for param_path, element in self.viewer_elements: folder_labels = param_path.split("/")[:-1] nested_folder_install(folder_labels, [], element) @@ -282,7 +323,8 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem if isinstance(pipeline.model, SplatfactoModel): self.viser_server.scene.add_point_cloud( "/gaussian_splatting_initial_points", - points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO, + points=pipeline.model.means.numpy(force=True) + * VISER_NERFSTUDIO_SCALE_RATIO, colors=(255, 0, 0), point_size=0.01, point_shape="circle", @@ -317,7 +359,10 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: R = vtf.SO3(wxyz=client.camera.wxyz) R = R @ vtf.SO3.from_x_radians(np.pi) R = torch.tensor(R.as_matrix()) - pos = torch.tensor(client.camera.position, dtype=torch.float64) / VISER_NERFSTUDIO_SCALE_RATIO + pos = ( + torch.tensor(client.camera.position, dtype=torch.float64) + / VISER_NERFSTUDIO_SCALE_RATIO + ) c2w = torch.concatenate([R, pos[:, None]], dim=1) if self.ready and self.render_tab_state.preview_render: camera_type = self.render_tab_state.preview_camera_type @@ -326,13 +371,19 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: aspect=self.render_tab_state.preview_aspect, c2w=c2w, time=self.render_tab_state.preview_time, - camera_type=CameraType.PERSPECTIVE - if camera_type == "Perspective" - else CameraType.FISHEYE - if camera_type == "Fisheye" - else CameraType.EQUIRECTANGULAR - if camera_type == "Equirectangular" - else assert_never(camera_type), + camera_type=( + CameraType.PERSPECTIVE + if camera_type == "Perspective" + else ( + CameraType.FISHEYE + if camera_type == "Fisheye" + else ( + CameraType.EQUIRECTANGULAR + if camera_type == "Equirectangular" + else assert_never(camera_type) + ) + ) + ), idx=self.current_camera_idx, ) else: @@ -350,7 +401,9 @@ def handle_disconnect(self, client: viser.ClientHandle) -> None: self.render_statemachines.pop(client.client_id) def handle_new_client(self, client: viser.ClientHandle) -> None: - self.render_statemachines[client.client_id] = RenderStateMachine(self, VISER_NERFSTUDIO_SCALE_RATIO, client) + self.render_statemachines[client.client_id] = RenderStateMachine( + self, VISER_NERFSTUDIO_SCALE_RATIO, client + ) self.render_statemachines[client.client_id].start() @client.camera.on_update @@ -360,7 +413,9 @@ def _(_: viser.CameraHandle) -> None: self.last_move_time = time.time() with self.viser_server.atomic(): camera_state = self.get_camera_state(client) - self.render_statemachines[client.client_id].action(RenderAction("move", camera_state)) + self.render_statemachines[client.client_id].action( + RenderAction("move", camera_state) + ) def set_camera_visibility(self, visible: bool) -> None: """Toggle the visibility of the training cameras.""" @@ -381,15 +436,23 @@ def update_camera_poses(self): idxs = list(self.camera_handles.keys()) with torch.no_grad(): assert isinstance(camera_optimizer, CameraOptimizer) - c2ws_delta = camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device)).cpu().numpy() + c2ws_delta = ( + camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device)) + .cpu() + .numpy() + ) for i, key in enumerate(idxs): # both are numpy arrays c2w_orig = self.original_c2w[key] c2w_delta = c2ws_delta[i, ...] - c2w = c2w_orig @ np.concatenate((c2w_delta, np.array([[0, 0, 0, 1]])), axis=0) + c2w = c2w_orig @ np.concatenate( + (c2w_delta, np.array([[0, 0, 0, 1]])), axis=0 + ) R = vtf.SO3.from_matrix(c2w[:3, :3]) # type: ignore R = R @ vtf.SO3.from_x_radians(np.pi) - self.camera_handles[key].position = c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO + self.camera_handles[key].position = ( + c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO + ) self.camera_handles[key].wxyz = R.wxyz def _trigger_rerender(self) -> None: @@ -429,7 +492,9 @@ def _pick_drawn_image_idxs(self, total_num: int) -> list[int]: else: num_display_images = min(self.config.max_num_display_images, total_num) # draw indices, roughly evenly spaced - return np.linspace(0, total_num - 1, num_display_images, dtype=np.int32).tolist() + return np.linspace( + 0, total_num - 1, num_display_images, dtype=np.int32 + ).tolist() def init_scene( self, @@ -473,7 +538,9 @@ def init_scene( ) def create_on_click_callback(capture_idx): - def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None: + def on_click_callback( + event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle], + ) -> None: with event.client.atomic(): event.client.camera.position = event.target.position event.client.camera.wxyz = event.target.wxyz @@ -503,12 +570,18 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N # this stops training while moving to make the response smoother while time.time() - self.last_move_time < 0.1: time.sleep(0.05) - if self.trainer is not None and self.trainer.training_state == "training" and self.train_util != 1: + if ( + self.trainer is not None + and self.trainer.training_state == "training" + and self.train_util != 1 + ): if ( EventName.TRAIN_RAYS_PER_SEC.value in GLOBAL_BUFFER["events"] and EventName.VIS_RAYS_PER_SEC.value in GLOBAL_BUFFER["events"] ): - train_s = GLOBAL_BUFFER["events"][EventName.TRAIN_RAYS_PER_SEC.value]["avg"] + train_s = GLOBAL_BUFFER["events"][EventName.TRAIN_RAYS_PER_SEC.value][ + "avg" + ] vis_s = GLOBAL_BUFFER["events"][EventName.VIS_RAYS_PER_SEC.value]["avg"] train_util = self.train_util vis_n = self.control_panel.max_res**2 @@ -516,7 +589,9 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N train_time = train_n / train_s vis_time = vis_n / vis_s - render_freq = train_util * vis_time / (train_time - train_util * train_time) + render_freq = ( + train_util * vis_time / (train_time - train_util * train_time) + ) else: render_freq = 30 if step > self.last_step + render_freq: @@ -525,7 +600,9 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N for id in clients: camera_state = self.get_camera_state(clients[id]) if camera_state is not None: - self.render_statemachines[id].action(RenderAction("step", camera_state)) + self.render_statemachines[id].action( + RenderAction("step", camera_state) + ) self.update_camera_poses() self.update_step(step)