Skip to content

Commit

Permalink
init keyframe value capture for image embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
ginazhouhuiwu committed Oct 10, 2024
1 parent cea20a7 commit 8bed243
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions nerfstudio/viewer/render_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import json
import threading
import time
import torch

from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

Expand All @@ -30,7 +32,10 @@
import viser.transforms as tf
from scipy import interpolate

from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle

Check failure on line 35 in nerfstudio/viewer/render_panel.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

nerfstudio/viewer/render_panel.py:35:61: F401 `nerfstudio.cameras.cameras.RayBundle` imported but unused
from nerfstudio.cameras.camera_utils import quaternion_matrix
from nerfstudio.models.base_model import Model
from nerfstudio.utils import colormaps

Check failure on line 38 in nerfstudio/viewer/render_panel.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

nerfstudio/viewer/render_panel.py:38:30: F401 `nerfstudio.utils.colormaps` imported but unused
from nerfstudio.viewer.control_panel import ControlPanel

if TYPE_CHECKING:

Check failure on line 41 in nerfstudio/viewer/render_panel.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

nerfstudio/viewer/render_panel.py:15:1: I001 Import block is un-sorted or un-formatted
Expand Down Expand Up @@ -66,9 +71,14 @@ def from_camera(camera: viser.CameraHandle, aspect: float) -> Keyframe:

class CameraPath:
def __init__(
self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float], time_enabled: bool = False
self,
server: viser.ViserServer,
duration_element: viser.GuiInputHandle[float],
viewer_model: Model,
time_enabled: bool = False
):
self._server = server
self._model = viewer_model
self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}
self._keyframe_counter: int = 0
self._spline_nodes: List[viser.SceneNodeHandle] = []
Expand Down Expand Up @@ -167,6 +177,7 @@ def _(_) -> None:

delete_button = server.gui.add_button("Delete", color="red", icon=viser.Icon.TRASH)
go_to_button = server.gui.add_button("Go to")
capture_button = server.gui.add_button("Capture")
close_button = server.gui.add_button("Close")

@override_fov.on_update
Expand Down Expand Up @@ -232,6 +243,42 @@ def _(event: viser.GuiEvent) -> None:
client.camera.wxyz = T_world_set.rotation().wxyz
client.camera.position = T_world_set.translation()
time.sleep(1.0 / 30.0)

@capture_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Render the scene from this keyframe's pose and save it as ``capture.png``.

    Builds a perspective camera from the keyframe's orientation and position,
    runs the viewer model on it in eval mode, and writes the resulting ``rgb``
    output to ``capture.png`` in the current working directory.

    Args:
        event: The GUI click event (unused).
    """
    from PIL import Image

    # Imported lazily, as in the original code.
    # NOTE(review): presumably this avoids a circular import with the viewer module - confirm.
    from nerfstudio.viewer.viewer import VISER_NERFSTUDIO_SCALE_RATIO

    # Fixed capture resolution in pixels.
    # NOTE(review): hard-coded; consider deriving these from the render resolution settings.
    image_height = 607.0
    image_width = 1060.0

    # Keyframe pose -> 3x4 camera-to-world matrix. The extra pi rotation about x
    # presumably converts between the viser and nerfstudio camera axis conventions;
    # the position is rescaled from viser units to nerfstudio units.
    c2w = tf.SE3.from_rotation_and_translation(
        tf.SO3(keyframe.wxyz) @ tf.SO3.from_x_radians(np.pi),
        keyframe.position / VISER_NERFSTUDIO_SCALE_RATIO,
    ).as_matrix()
    # as_matrix() yields float64; cast explicitly so the pose matches float32 intrinsics.
    c2w = torch.tensor(c2w[:3], dtype=torch.float32)

    # NOTE(review): fx uses width / 2 while fy uses the full height, which gives
    # non-square pixels. Values are kept as-is; confirm whether fy should match fx
    # or be derived from the keyframe's field of view.
    camera = Cameras(
        fx=image_width / 2,
        fy=image_height,
        cx=image_width / 2,
        cy=image_height / 2,
        camera_to_worlds=c2w,
        camera_type=CameraType.PERSPECTIVE,
        times=None,
    ).to(self._model.device)  # rays must be generated on the model's device

    # eval() switches every submodule (not just the top-level flag); the previous
    # mode is restored afterwards so an ongoing training run is not left in eval mode.
    was_training = self._model.training
    self._model.eval()
    try:
        with torch.no_grad():
            outputs = self._model.get_outputs_for_camera(camera)
    finally:
        self._model.train(was_training)

    # .cpu() is required before .numpy() when the model renders on an accelerator.
    rgb = outputs["rgb"].detach().cpu().numpy()
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    im = Image.fromarray((np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8))
    im.save("capture.png")

@close_button.on_click
def _(_) -> None:
Expand Down Expand Up @@ -1169,9 +1216,9 @@ def _(_) -> None:
modal.close()

if control_panel is not None:
camera_path = CameraPath(server, duration_number, control_panel._time_enabled)
camera_path = CameraPath(server, duration_number, viewer_model, control_panel._time_enabled)
else:
camera_path = CameraPath(server, duration_number)
camera_path = CameraPath(server, duration_number, viewer_model)
camera_path.tension = tension_slider.value
camera_path.default_fov = fov_degrees.value / 180.0 * np.pi
camera_path.default_transition_sec = transition_sec_number.value
Expand Down

0 comments on commit 8bed243

Please sign in to comment.