diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 12958732d3..59cd36cf07 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -21,7 +21,7 @@ import threading import time from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import numpy as np import splines @@ -33,6 +33,9 @@ from nerfstudio.models.base_model import Model from nerfstudio.viewer.control_panel import ControlPanel +if TYPE_CHECKING: + from viser import GuiInputHandle + @dataclasses.dataclass class Keyframe: @@ -523,7 +526,7 @@ def populate_render_tab( datapath: Path, viewer_model: Model, control_panel: Optional[ControlPanel] = None, -) -> RenderTabState: +) -> Tuple[RenderTabState, GuiInputHandle, GuiInputHandle, GuiInputHandle]: from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO render_tab_state = RenderTabState( @@ -1164,216 +1167,6 @@ def _(event: viser.GuiEvent) -> None: @close_button.on_click def _(_) -> None: modal.close() - - auto_camera_folder = server.gui.add_folder("Automatic Camera Path") - with auto_camera_folder: - click_position = np.array([0.0, 0.0, -5.0]) - select_center_button = server.gui.add_button( - "Select Center", - icon=viser.Icon.CROSSHAIR, - hint="Choose center point to generate camera path around.", - ) - - @select_center_button.on_click - def _(event: viser.GuiEvent) -> None: - select_center_button.disabled = True - - @event.client.scene.on_pointer_event(event_type="click") - def _(event: viser.ScenePointerEvent) -> None: - # Code mostly borrowed from garfield.studio! - import torch - from nerfstudio.cameras.rays import RayBundle - from nerfstudio.field_components.field_heads import FieldHeadNames - from nerfstudio.model_components.losses import scale_gradients_by_distance_squared - - origin = torch.tensor(event.ray_origin).view(1, 3) - direction = torch.tensor(event.ray_direction).view(1, 3) - - # Get intersection - bundle = RayBundle( - origins=origin, - directions=direction, - pixel_area=torch.tensor(0.001).view(1, 1), - camera_indices=torch.tensor(0).view(1, 1), - nears=torch.tensor(0.05).view(1, 1), - fars=torch.tensor(100).view(1, 1), - ).to("cuda") - - # Get the distance/depth to the intersection --> calculate 3D position of the click - ray_samples, _, _ = viewer_model.proposal_sampler(bundle, density_fns=viewer_model.density_fns) - field_outputs = viewer_model.field.forward(ray_samples, compute_normals=viewer_model.config.predict_normals) - if viewer_model.config.use_gradient_scaling: - field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples) - weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY]) - with torch.no_grad(): - depth = viewer_model.renderer_depth(weights=weights, ray_samples=ray_samples) - distance = depth[0, 0].detach().cpu().numpy() - - nonlocal click_position - click_position = np.array(origin + direction * distance).reshape(3,) - - server.scene.add_icosphere( - f"/render_center_pos", - radius=0.1, - color=(200, 10, 30), - position=click_position, - ) - - event.client.scene.remove_pointer_callback() - - @event.client.scene.on_pointer_callback_removed - def _(): - select_center_button.disabled = False - - num_cameras_handle = server.gui.add_number( - label="Number of Cameras", - initial_value=3, - hint="Total number of cameras generated in path, placed equidistant from neighboring ones.", - ) - - radius_handle = server.gui.add_number( - label="Radius", - initial_value=4, - hint="Radius of circular camera path.", - ) - - camera_height_handle = server.gui.add_number( - label="Height", - initial_value=2, - hint="Height of cameras with respect to chosen origin.", - ) - - circular_camera_path_button = server.gui.add_button( - "Generate Circular Camera Path", - icon=viser.Icon.CAMERA, - hint="Automatically generate a circular camera path around selected point.", - ) - - def wxyz_helper(camera_position: np.ndarray) -> np.ndarray: - # Calculates the camera direction from position to click_position - camera_direction = camera_position - click_position - camera_direction = camera_direction / np.linalg.norm(camera_direction) - - global_up = np.array([0.0, 0.0, 1.0]) - - camera_right = np.cross(camera_direction, global_up) - camera_right_norm = np.linalg.norm(camera_right) - if camera_right_norm > 0: - camera_right = camera_right / camera_right_norm - - camera_up = np.cross(camera_right, camera_direction) - - R = np.array([-camera_right, -camera_up, -camera_direction]).T - - w = np.sqrt(1 + R[0, 0] + R[1, 1] + R[2, 2]) / 2 - x = (R[2, 1] - R[1, 2]) / (4 * w) - y = (R[0, 2] - R[2, 0]) / (4 * w) - z = (R[1, 0] - R[0, 1]) / (4 * w) - return np.array([w, x, y, z]) - else: - return np.array([1.0, 0.0, 0.0, 0.0]) - - camera_coords = [] - @circular_camera_path_button.on_click - def _(event: viser.GuiEvent) -> None: - nonlocal click_position, num_cameras_handle, radius_handle, camera_height_handle, camera_coords - num_cameras = num_cameras_handle.value - radius = radius_handle.value - camera_height = camera_height_handle.value - - camera_coords = [] - for i in range(num_cameras): - camera_coords.append(click_position + - np.array([radius * np.cos(2 * np.pi * i / num_cameras), - radius * np.sin(2 * np.pi * i/ num_cameras), - camera_height])) - - fov = event.client.camera.fov - for i, position in enumerate(camera_coords): - camera_path.add_camera( - keyframe=Keyframe( - position=position, - wxyz=wxyz_helper(position), - override_fov_enabled=False, - override_fov_rad=fov, - override_time_enabled=False, - override_time_val=0.0, - aspect=resolution.value[0] / resolution.value[1], - override_transition_enabled=False, - override_transition_sec=None, - ), - keyframe_index = i, - ) - duration_number.value = camera_path.compute_duration() - camera_path.update_spline() - - optimize_button = server.gui.add_button( - "Optimize Camera Path", - icon=viser.Icon.CAMERA, - hint="Optimizes camera path for object avoidance iteratively.", - ) - - @optimize_button.on_click - def _(event: viser.GuiEvent) -> None: - import torch - from nerfstudio.cameras.rays import RayBundle - from nerfstudio.field_components.field_heads import FieldHeadNames - from nerfstudio.model_components.losses import scale_gradients_by_distance_squared - - nonlocal camera_coords - - directions = [[1, 0, 0], - [-1, 0, 0], - [0, 1, 0], - [0, -1, 0], - [0, 0, 1], - [0, 0, -1]] - - for i, position in enumerate(camera_coords): - raylen = 2.0 - origins = torch.tensor(np.tile(position, (6, 1))) - pixel_area = torch.ones_like(origins[..., 0:1]) - camera_indices = torch.zeros_like(origins[..., 0:1]).int() - nears = torch.zeros_like(origins[..., 0:1]) - fars = torch.ones_like(origins[..., 0:1]) * raylen - directions_norm = torch.ones_like(origins[..., 0:1]) - viewer_model.training = False - - bundle = RayBundle( - origins=origins, - directions=torch.tensor(directions), - pixel_area=pixel_area, - camera_indices=camera_indices, - nears=nears, - fars=fars, - metadata={"directions_norm": directions_norm}, - ).to('cuda') - - outputs = viewer_model.get_outputs(bundle) - distances = outputs["expected_depth"].detach().cpu().numpy() - - loss = -min(distances) - if loss > -0.4: - position = position - directions[np.argmin(distances)] * 1 - # backprop through the nerf as the gradient step, input is position - camera_path.add_camera( - keyframe=Keyframe( - position=position, - wxyz=wxyz_helper(position), - override_fov_enabled=False, - override_fov_rad=event.client.camera.fov, - override_time_enabled=False, - override_time_val=0.0, - aspect=resolution.value[0] / resolution.value[1], - override_transition_enabled=False, - override_transition_sec=None, - ), - keyframe_index = i, - ) - duration_number.value = camera_path.compute_duration() - camera_path.update_spline() - - camera_coords[i] = position if control_panel is not None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) @@ -1383,7 +1176,7 @@ def _(event: viser.GuiEvent) -> None: camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value - return render_tab_state + return render_tab_state, camera_path, duration_number, resolution if __name__ == "__main__": diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index 4137dd44ef..1ca514d73b 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -94,7 +94,9 @@ def __init__( self.include_time = self.pipeline.datamanager.includes_time if self.config.websocket_port is None: - websocket_port = viewer_utils.get_free_port(default_port=self.config.websocket_port_default) + websocket_port = viewer_utils.get_free_port( + default_port=self.config.websocket_port_default + ) else: websocket_port = self.config.websocket_port self.log_filename.parent.mkdir(exist_ok=True) @@ -106,12 +108,16 @@ def __init__( self.train_btn_state: Literal["training", "paused", "completed"] = ( "training" if self.trainer is None else self.trainer.training_state ) - self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state + self._prev_train_state: Literal["training", "paused", "completed"] = ( + self.train_btn_state + ) self.last_move_time = 0 # track the camera index that last being clicked self.current_camera_idx = 0 - self.viser_server = viser.ViserServer(host=config.websocket_host, port=websocket_port) + self.viser_server = viser.ViserServer( + host=config.websocket_host, port=websocket_port + ) # Set the name of the URL either to the share link if available, or the localhost share_url = None if share: @@ -120,15 +126,21 @@ def __init__( print("Couldn't make share URL!") if share_url is not None: - self.viewer_info = [f"Viewer at: http://localhost:{websocket_port} or {share_url}"] + self.viewer_info = [ + f"Viewer at: http://localhost:{websocket_port} or {share_url}" + ] elif config.websocket_host == "0.0.0.0": # 0.0.0.0 is not a real IP address and was confusing people, so # we'll just print localhost instead. There are some security # (and IPv6 compatibility) implications here though, so we should # note that the server is bound to 0.0.0.0! - self.viewer_info = [f"Viewer running locally at: http://localhost:{websocket_port} (listening on 0.0.0.0)"] + self.viewer_info = [ + f"Viewer running locally at: http://localhost:{websocket_port} (listening on 0.0.0.0)" + ] else: - self.viewer_info = [f"Viewer running locally at: http://{config.websocket_host}:{websocket_port}"] + self.viewer_info = [ + f"Viewer running locally at: http://{config.websocket_host}:{websocket_port}" + ] buttons = ( viser.theme.TitlebarButton( @@ -195,8 +207,8 @@ def __init__( self.show_images.visible = False mkdown = self.make_stats_markdown(0, "0x0px") self.stats_markdown = self.viser_server.gui.add_markdown(mkdown) - tabs = self.viser_server.gui.add_tab_group() - control_tab = tabs.add_tab("Control", viser.Icon.SETTINGS) + self.tabs = self.viser_server.gui.add_tab_group() + control_tab = self.tabs.add_tab("Control", viser.Icon.SETTINGS) with control_tab: self.control_panel = ControlPanel( self.viser_server, @@ -209,13 +221,24 @@ def __init__( ) config_path = self.log_filename.parents[0] / "config.yml" - with tabs.add_tab("Render", viser.Icon.CAMERA): - self.render_tab_state = populate_render_tab( - self.viser_server, config_path, self.datapath, self.pipeline.model, self.control_panel + with self.tabs.add_tab("Render", viser.Icon.CAMERA): + ( + self.render_tab_state, + self.camera_path, + self.duration_number, + self.resolution, + ) = populate_render_tab( + self.viser_server, + config_path, + self.datapath, + self.pipeline.model, + self.control_panel, ) - with tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): - populate_export_tab(self.viser_server, self.control_panel, config_path, self.pipeline.model) + with self.tabs.add_tab("Export", viser.Icon.PACKAGE_EXPORT): + populate_export_tab( + self.viser_server, self.control_panel, config_path, self.pipeline.model + ) # Keep track of the pointers to generated GUI folders, because each generated folder holds a unique ID. viewer_gui_folders = dict() @@ -225,17 +248,26 @@ def prev_cb_wrapper(prev_cb): # concurrently executing render thread. This may block rendering, however this can be necessary # if the callback uses get_outputs internally. def cb_lock(element): - with self.train_lock if self.train_lock is not None else contextlib.nullcontext(): + with ( + self.train_lock + if self.train_lock is not None + else contextlib.nullcontext() + ): prev_cb(element) return cb_lock - def nested_folder_install(folder_labels: List[str], prev_labels: List[str], element: ViewerElement): + def nested_folder_install( + folder_labels: List[str], prev_labels: List[str], element: ViewerElement + ): if len(folder_labels) == 0: element.install(self.viser_server) # also rewire the hook to rerender prev_cb = element.cb_hook - element.cb_hook = lambda element: [prev_cb_wrapper(prev_cb)(element), self._trigger_rerender()] + element.cb_hook = lambda element: [ + prev_cb_wrapper(prev_cb)(element), + self._trigger_rerender(), + ] else: # recursively create folders # If the folder name is "Custom Elements/a/b", then: @@ -251,12 +283,18 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem # Otherwise, use the existing folder as context manager. folder_path = "/".join(prev_labels + [folder_labels[0]]) if folder_path not in viewer_gui_folders: - viewer_gui_folders[folder_path] = self.viser_server.gui.add_folder(folder_labels[0]) + viewer_gui_folders[folder_path] = self.viser_server.gui.add_folder( + folder_labels[0] + ) with viewer_gui_folders[folder_path]: - nested_folder_install(folder_labels[1:], prev_labels + [folder_labels[0]], element) + nested_folder_install( + folder_labels[1:], prev_labels + [folder_labels[0]], element + ) with control_tab: - from nerfstudio.viewer_legacy.server.viewer_elements import ViewerElement as LegacyViewerElement + from nerfstudio.viewer_legacy.server.viewer_elements import ( + ViewerElement as LegacyViewerElement, + ) if len(parse_object(pipeline, LegacyViewerElement, "Custom Elements")) > 0: from nerfstudio.utils.rich_utils import CONSOLE @@ -266,7 +304,9 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem style="bold yellow", ) self.viewer_elements = [] - self.viewer_elements.extend(parse_object(pipeline, ViewerElement, "Custom Elements")) + self.viewer_elements.extend( + parse_object(pipeline, ViewerElement, "Custom Elements") + ) for param_path, element in self.viewer_elements: folder_labels = param_path.split("/")[:-1] nested_folder_install(folder_labels, [], element) @@ -283,7 +323,8 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem if isinstance(pipeline.model, SplatfactoModel): self.viser_server.scene.add_point_cloud( "/gaussian_splatting_initial_points", - points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO, + points=pipeline.model.means.numpy(force=True) + * VISER_NERFSTUDIO_SCALE_RATIO, colors=(255, 0, 0), point_size=0.01, point_shape="circle", @@ -318,7 +359,10 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: R = vtf.SO3(wxyz=client.camera.wxyz) R = R @ vtf.SO3.from_x_radians(np.pi) R = torch.tensor(R.as_matrix()) - pos = torch.tensor(client.camera.position, dtype=torch.float64) / VISER_NERFSTUDIO_SCALE_RATIO + pos = ( + torch.tensor(client.camera.position, dtype=torch.float64) + / VISER_NERFSTUDIO_SCALE_RATIO + ) c2w = torch.concatenate([R, pos[:, None]], dim=1) if self.ready and self.render_tab_state.preview_render: camera_type = self.render_tab_state.preview_camera_type @@ -327,13 +371,19 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: aspect=self.render_tab_state.preview_aspect, c2w=c2w, time=self.render_tab_state.preview_time, - camera_type=CameraType.PERSPECTIVE - if camera_type == "Perspective" - else CameraType.FISHEYE - if camera_type == "Fisheye" - else CameraType.EQUIRECTANGULAR - if camera_type == "Equirectangular" - else assert_never(camera_type), + camera_type=( + CameraType.PERSPECTIVE + if camera_type == "Perspective" + else ( + CameraType.FISHEYE + if camera_type == "Fisheye" + else ( + CameraType.EQUIRECTANGULAR + if camera_type == "Equirectangular" + else assert_never(camera_type) + ) + ) + ), idx=self.current_camera_idx, ) else: @@ -351,7 +401,9 @@ def handle_disconnect(self, client: viser.ClientHandle) -> None: self.render_statemachines.pop(client.client_id) def handle_new_client(self, client: viser.ClientHandle) -> None: - self.render_statemachines[client.client_id] = RenderStateMachine(self, VISER_NERFSTUDIO_SCALE_RATIO, client) + self.render_statemachines[client.client_id] = RenderStateMachine( + self, VISER_NERFSTUDIO_SCALE_RATIO, client + ) self.render_statemachines[client.client_id].start() @client.camera.on_update @@ -361,7 +413,9 @@ def _(_: viser.CameraHandle) -> None: self.last_move_time = time.time() with self.viser_server.atomic(): camera_state = self.get_camera_state(client) - self.render_statemachines[client.client_id].action(RenderAction("move", camera_state)) + self.render_statemachines[client.client_id].action( + RenderAction("move", camera_state) + ) def set_camera_visibility(self, visible: bool) -> None: """Toggle the visibility of the training cameras.""" @@ -382,15 +436,23 @@ def update_camera_poses(self): idxs = list(self.camera_handles.keys()) with torch.no_grad(): assert isinstance(camera_optimizer, CameraOptimizer) - c2ws_delta = camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device)).cpu().numpy() + c2ws_delta = ( + camera_optimizer(torch.tensor(idxs, device=camera_optimizer.device)) + .cpu() + .numpy() + ) for i, key in enumerate(idxs): # both are numpy arrays c2w_orig = self.original_c2w[key] c2w_delta = c2ws_delta[i, ...] - c2w = c2w_orig @ np.concatenate((c2w_delta, np.array([[0, 0, 0, 1]])), axis=0) + c2w = c2w_orig @ np.concatenate( + (c2w_delta, np.array([[0, 0, 0, 1]])), axis=0 + ) R = vtf.SO3.from_matrix(c2w[:3, :3]) # type: ignore R = R @ vtf.SO3.from_x_radians(np.pi) - self.camera_handles[key].position = c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO + self.camera_handles[key].position = ( + c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO + ) self.camera_handles[key].wxyz = R.wxyz def _trigger_rerender(self) -> None: @@ -430,7 +492,9 @@ def _pick_drawn_image_idxs(self, total_num: int) -> list[int]: else: num_display_images = min(self.config.max_num_display_images, total_num) # draw indices, roughly evenly spaced - return np.linspace(0, total_num - 1, num_display_images, dtype=np.int32).tolist() + return np.linspace( + 0, total_num - 1, num_display_images, dtype=np.int32 + ).tolist() def init_scene( self, @@ -474,7 +538,9 @@ def init_scene( ) def create_on_click_callback(capture_idx): - def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None: + def on_click_callback( + event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle], + ) -> None: with event.client.atomic(): event.client.camera.position = event.target.position event.client.camera.wxyz = event.target.wxyz @@ -504,12 +570,18 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N # this stops training while moving to make the response smoother while time.time() - self.last_move_time < 0.1: time.sleep(0.05) - if self.trainer is not None and self.trainer.training_state == "training" and self.train_util != 1: + if ( + self.trainer is not None + and self.trainer.training_state == "training" + and self.train_util != 1 + ): if ( EventName.TRAIN_RAYS_PER_SEC.value in GLOBAL_BUFFER["events"] and EventName.VIS_RAYS_PER_SEC.value in GLOBAL_BUFFER["events"] ): - train_s = GLOBAL_BUFFER["events"][EventName.TRAIN_RAYS_PER_SEC.value]["avg"] + train_s = GLOBAL_BUFFER["events"][EventName.TRAIN_RAYS_PER_SEC.value][ + "avg" + ] vis_s = GLOBAL_BUFFER["events"][EventName.VIS_RAYS_PER_SEC.value]["avg"] train_util = self.train_util vis_n = self.control_panel.max_res**2 @@ -517,7 +589,9 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N train_time = train_n / train_s vis_time = vis_n / vis_s - render_freq = train_util * vis_time / (train_time - train_util * train_time) + render_freq = ( + train_util * vis_time / (train_time - train_util * train_time) + ) else: render_freq = 30 if step > self.last_step + render_freq: @@ -526,7 +600,9 @@ def update_scene(self, step: int, num_rays_per_batch: Optional[int] = None) -> N for id in clients: camera_state = self.get_camera_state(clients[id]) if camera_state is not None: - self.render_statemachines[id].action(RenderAction("step", camera_state)) + self.render_statemachines[id].action( + RenderAction("step", camera_state) + ) self.update_camera_poses() self.update_step(step)