Skip to content

Commit

Permalink
restructure major features out of nerfstudio, minor changes for usability
Browse files Browse the repository at this point in the history
  • Loading branch information
ginazhouhuiwu committed Sep 27, 2024
1 parent 610e4e1 commit 9a60531
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 253 deletions.
219 changes: 6 additions & 213 deletions nerfstudio/viewer/render_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import threading
import time
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import splines
Expand All @@ -33,6 +33,9 @@
from nerfstudio.models.base_model import Model
from nerfstudio.viewer.control_panel import ControlPanel

if TYPE_CHECKING:
from viser import GuiInputHandle


@dataclasses.dataclass
class Keyframe:
Expand Down Expand Up @@ -523,7 +526,7 @@ def populate_render_tab(
datapath: Path,
viewer_model: Model,
control_panel: Optional[ControlPanel] = None,
) -> RenderTabState:
) -> Tuple[RenderTabState, GuiInputHandle, GuiInputHandle, GuiInputHandle]:
from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO

render_tab_state = RenderTabState(
Expand Down Expand Up @@ -1164,216 +1167,6 @@ def _(event: viser.GuiEvent) -> None:
# Click handler for close_button: closes the modal. The event argument is unused.
@close_button.on_click
def _(_) -> None:
modal.close()

# GUI folder labelled "Automatic Camera Path"; controls created while it is
# the active context are added to it.
auto_camera_folder = server.gui.add_folder("Automatic Camera Path")
with auto_camera_folder:
# Point the generated camera path is centered on. This is only the initial
# value; the select-center click handler rebinds it via `nonlocal`.
click_position = np.array([0.0, 0.0, -5.0])
# Button that lets the user pick the center point (see its hint text).
select_center_button = server.gui.add_button(
"Select Center",
icon=viser.Icon.CROSSHAIR,
hint="Choose center point to generate camera path around.",
)

@select_center_button.on_click
def _(event: viser.GuiEvent) -> None:
select_center_button.disabled = True

@event.client.scene.on_pointer_event(event_type="click")
def _(event: viser.ScenePointerEvent) -> None:
# Code mostly borrowed from garfield.studio!
import torch
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.model_components.losses import scale_gradients_by_distance_squared

origin = torch.tensor(event.ray_origin).view(1, 3)
direction = torch.tensor(event.ray_direction).view(1, 3)

# Get intersection
bundle = RayBundle(
origins=origin,
directions=direction,
pixel_area=torch.tensor(0.001).view(1, 1),
camera_indices=torch.tensor(0).view(1, 1),
nears=torch.tensor(0.05).view(1, 1),
fars=torch.tensor(100).view(1, 1),
).to("cuda")

# Get the distance/depth to the intersection --> calculate 3D position of the click
ray_samples, _, _ = viewer_model.proposal_sampler(bundle, density_fns=viewer_model.density_fns)
field_outputs = viewer_model.field.forward(ray_samples, compute_normals=viewer_model.config.predict_normals)
if viewer_model.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)
weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
with torch.no_grad():
depth = viewer_model.renderer_depth(weights=weights, ray_samples=ray_samples)
distance = depth[0, 0].detach().cpu().numpy()

nonlocal click_position
click_position = np.array(origin + direction * distance).reshape(3,)

server.scene.add_icosphere(
f"/render_center_pos",
radius=0.1,
color=(200, 10, 30),
position=click_position,
)

event.client.scene.remove_pointer_callback()

@event.client.scene.on_pointer_callback_removed
def _():
select_center_button.disabled = False

# Number input: how many cameras the generated path contains (initially 3).
num_cameras_handle = server.gui.add_number(
label="Number of Cameras",
initial_value=3,
hint="Total number of cameras generated in path, placed equidistant from neighboring ones.",
)

# Number input: radius of the circular path (initially 4).
radius_handle = server.gui.add_number(
label="Radius",
initial_value=4,
hint="Radius of circular camera path.",
)

# Number input: camera height relative to the chosen center (initially 2).
camera_height_handle = server.gui.add_number(
label="Height",
initial_value=2,
hint="Height of cameras with respect to chosen origin.",
)

# Button that triggers generation of the circular path around the chosen center.
circular_camera_path_button = server.gui.add_button(
"Generate Circular Camera Path",
icon=viser.Icon.CAMERA,
hint="Automatically generate a circular camera path around selected point.",
)

def wxyz_helper(camera_position: np.ndarray) -> np.ndarray:
    """Return the wxyz quaternion of a camera at ``camera_position`` oriented towards ``click_position``.

    The rotation matrix has columns ``(-right, -up, -direction)``, where
    ``direction`` is the unit vector from ``click_position`` to the camera and
    ``right`` / ``up`` are derived from the global +Z axis.

    Args:
        camera_position: 3-vector with the camera location.

    Returns:
        Quaternion as ``[w, x, y, z]`` with ``w >= 0``. The identity quaternion
        ``[1, 0, 0, 0]`` is returned when the orientation is undefined: the
        camera coincides with ``click_position`` or sits exactly above/below it.
    """
    identity = np.array([1.0, 0.0, 0.0, 0.0])

    # Calculates the camera direction from position to click_position
    camera_direction = camera_position - click_position
    direction_norm = np.linalg.norm(camera_direction)
    # `not (... > 0)` also catches NaN, not just an exact zero.
    if not direction_norm > 0:
        return identity
    camera_direction = camera_direction / direction_norm

    global_up = np.array([0.0, 0.0, 1.0])

    camera_right = np.cross(camera_direction, global_up)
    camera_right_norm = np.linalg.norm(camera_right)
    if not camera_right_norm > 0:
        # View direction is parallel to global up: "right" is undefined.
        return identity
    camera_right = camera_right / camera_right_norm

    camera_up = np.cross(camera_right, camera_direction)

    R = np.array([-camera_right, -camera_up, -camera_direction]).T

    # Convert R to a quaternion by pivoting on the largest of the trace and the
    # diagonal entries. Dividing by `w` alone breaks down when trace(R) == -1
    # (a 180 degree rotation), where w == 0 and the result becomes inf/NaN.
    trace = R[0, 0] + R[1, 1] + R[2, 2]
    if trace > 0:
        s = 2.0 * np.sqrt(1.0 + trace)
        w = s / 4.0
        x = (R[2, 1] - R[1, 2]) / s
        y = (R[0, 2] - R[2, 0]) / s
        z = (R[1, 0] - R[0, 1]) / s
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        s = 2.0 * np.sqrt(max(1.0 + R[0, 0] - R[1, 1] - R[2, 2], 0.0))
        w = (R[2, 1] - R[1, 2]) / s
        x = s / 4.0
        y = (R[0, 1] + R[1, 0]) / s
        z = (R[0, 2] + R[2, 0]) / s
    elif R[1, 1] > R[2, 2]:
        s = 2.0 * np.sqrt(max(1.0 + R[1, 1] - R[0, 0] - R[2, 2], 0.0))
        w = (R[0, 2] - R[2, 0]) / s
        x = (R[0, 1] + R[1, 0]) / s
        y = s / 4.0
        z = (R[1, 2] + R[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(max(1.0 + R[2, 2] - R[0, 0] - R[1, 1], 0.0))
        w = (R[1, 0] - R[0, 1]) / s
        x = (R[0, 2] + R[2, 0]) / s
        y = (R[1, 2] + R[2, 1]) / s
        z = s / 4.0

    quat = np.array([w, x, y, z])
    # q and -q encode the same rotation; keep w non-negative, as the
    # previous sqrt-based formula always produced.
    if quat[0] < 0:
        quat = -quat
    return quat

# Camera positions produced by the most recent "Generate Circular Camera Path"
# click; the handler rebinds this list via `nonlocal` each time it runs.
camera_coords = []
# Click handler: places cameras on a circle around click_position and adds one
# keyframe per camera to camera_path.
@circular_camera_path_button.on_click
def _(event: viser.GuiEvent) -> None:
# Only camera_coords is reassigned here; the other names are only read.
nonlocal click_position, num_cameras_handle, radius_handle, camera_height_handle, camera_coords
num_cameras = num_cameras_handle.value
radius = radius_handle.value
camera_height = camera_height_handle.value

# Camera i sits at angle 2*pi*i/num_cameras on a circle of the given radius
# in the XY plane, offset by camera_height along Z from click_position.
# NOTE(review): range() needs an int — confirm the number input cannot yield a float.
camera_coords = []
for i in range(num_cameras):
camera_coords.append(click_position +
np.array([radius * np.cos(2 * np.pi * i / num_cameras),
radius * np.sin(2 * np.pi * i/ num_cameras),
camera_height]))

# The field of view comes from the clicking client's current camera;
# override_fov_enabled is False, so it is stored but not enabled as an override.
fov = event.client.camera.fov
for i, position in enumerate(camera_coords):
# The loop index is passed as keyframe_index.
# NOTE(review): presumably this replaces an existing keyframe with the same
# index — confirm against CameraPath.add_camera.
camera_path.add_camera(
keyframe=Keyframe(
position=position,
wxyz=wxyz_helper(position),
override_fov_enabled=False,
override_fov_rad=fov,
override_time_enabled=False,
override_time_val=0.0,
aspect=resolution.value[0] / resolution.value[1],
override_transition_enabled=False,
override_transition_sec=None,
),
keyframe_index = i,
)
# Refresh the displayed duration and the spline after adding keyframes.
duration_number.value = camera_path.compute_duration()
camera_path.update_spline()

# Button that triggers the camera-path optimization handler (see its hint text).
optimize_button = server.gui.add_button(
"Optimize Camera Path",
icon=viser.Icon.CAMERA,
hint="Optimizes camera path for object avoidance iteratively.",
)

# Click handler: for every generated camera position, probes the model along the
# six axis directions and moves the camera when the nearest depth is below 0.4.
@optimize_button.on_click
def _(event: viser.GuiEvent) -> None:
import torch
from nerfstudio.cameras.rays import RayBundle
# NOTE(review): the next two imports are not used anywhere in this handler.
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.model_components.losses import scale_gradients_by_distance_squared

nonlocal camera_coords

# Probe directions: +/-X, +/-Y, +/-Z.
directions = [[1, 0, 0],
[-1, 0, 0],
[0, 1, 0],
[0, -1, 0],
[0, 0, 1],
[0, 0, -1]]

for i, position in enumerate(camera_coords):
# Each probe ray is limited to `raylen` (used as the `fars` value below).
raylen = 2.0
# Six copies of the camera position, one origin per probe direction.
origins = torch.tensor(np.tile(position, (6, 1)))
pixel_area = torch.ones_like(origins[..., 0:1])
camera_indices = torch.zeros_like(origins[..., 0:1]).int()
nears = torch.zeros_like(origins[..., 0:1])
fars = torch.ones_like(origins[..., 0:1]) * raylen
directions_norm = torch.ones_like(origins[..., 0:1])
# NOTE(review): sets the attribute directly, presumably to get evaluation-mode
# behavior — confirm whether viewer_model.eval() is intended instead.
viewer_model.training = False

# NOTE(review): `directions` holds Python ints, so torch.tensor(directions) is
# presumably an integer tensor while `origins` comes from a NumPy array — confirm
# the model accepts these dtypes. The device is hard-coded to 'cuda'.
bundle = RayBundle(
origins=origins,
directions=torch.tensor(directions),
pixel_area=pixel_area,
camera_indices=camera_indices,
nears=nears,
fars=fars,
metadata={"directions_norm": directions_norm},
).to('cuda')

outputs = viewer_model.get_outputs(bundle)
distances = outputs["expected_depth"].detach().cpu().numpy()

# loss is the negated smallest probe depth, so `loss > -0.4` means the
# nearest depth along some probe direction is below 0.4.
loss = -min(distances)
if loss > -0.4:
# directions[...] is a Python list, so `* 1` is list repetition and leaves it
# unchanged; the step away from the nearest direction therefore has unit length.
position = position - directions[np.argmin(distances)] * 1
# backprop through the nerf as the gradient step, input is position
camera_path.add_camera(
keyframe=Keyframe(
position=position,
wxyz=wxyz_helper(position),
override_fov_enabled=False,
override_fov_rad=event.client.camera.fov,
override_time_enabled=False,
override_time_val=0.0,
aspect=resolution.value[0] / resolution.value[1],
override_transition_enabled=False,
override_transition_sec=None,
),
keyframe_index = i,
)
duration_number.value = camera_path.compute_duration()
camera_path.update_spline()

# Store the (possibly moved) position for subsequent clicks.
camera_coords[i] = position

if control_panel is not None:
camera_path = CameraPath(server, duration_number, control_panel._time_enabled)
Expand All @@ -1383,7 +1176,7 @@ def _(event: viser.GuiEvent) -> None:
camera_path.default_fov = fov_degrees.value / 180.0 * np.pi
camera_path.default_transition_sec = transition_sec_number.value

return render_tab_state
return render_tab_state, camera_path, duration_number, resolution


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 9a60531

Please sign in to comment.