Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback, scene node removal API improvements #290

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Dict,
Generic,
Iterable,
Literal,
Tuple,
TypeVar,
cast,
Expand Down Expand Up @@ -106,6 +107,8 @@ class _GuiHandleState(Generic[T]):
sync_cb: Callable[[ClientId, dict[str, Any]], None] | None = None
"""Callback for synchronizing inputs across clients."""

removed: bool = False


class _OverridableGuiPropApi:
"""Mixin that allows reading/assigning properties defined in each scene node message."""
Expand Down Expand Up @@ -157,6 +160,17 @@ def __init__(self, _impl: _GuiHandleState[T]) -> None:

def remove(self) -> None:
"""Permanently remove this GUI element from the visualizer."""

# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Send remove to client(s) + update internal state.
self._impl.gui_api._websock_interface.queue_message(
GuiRemoveMessage(self._impl.id)
)
Expand Down Expand Up @@ -241,10 +255,25 @@ class GuiInputHandle(_GuiInputHandle[T], Generic[T]):
def on_update(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], Any]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a GUI input is updated. Happens in a thread."""
"""Attach a function to call when a GUI input is updated. Callbacks stack (need
to be manually removed via :meth:`remove_update_callback()`) and will be called
from a thread."""
self._impl.update_cb.append(func)
return func

def remove_update_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove update callbacks from the GUI input.

Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl.update_cb.clear()
else:
self._impl.update_cb = [cb for cb in self._impl.update_cb if cb != callback]


class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps):
"""Handle for checkbox inputs.
Expand Down Expand Up @@ -506,6 +535,16 @@ def __post_init__(self) -> None:

def remove(self) -> None:
"""Remove this tab group and all contained GUI elements."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Remove tabs, then self.
for tab in tuple(self._tab_handles):
tab.remove()
gui_api = self._impl.gui_api
Expand All @@ -524,6 +563,7 @@ class GuiTabHandle:
_children: dict[str, SupportsRemoveProtocol] = dataclasses.field(
default_factory=dict
)
_removed: bool = False

def __enter__(self) -> GuiTabHandle:
self._container_id_restore = self._parent._impl.gui_api._get_container_id()
Expand All @@ -542,6 +582,15 @@ def __post_init__(self) -> None:
def remove(self) -> None:
"""Permanently remove this tab and all contained GUI elements from the
visualizer."""
# Warn if already removed.
if self._removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._removed = True

# We may want to make this thread-safe in the future.
found_index = -1
for i, tab in enumerate(self._parent._tab_handles):
Expand Down Expand Up @@ -594,6 +643,16 @@ def __exit__(self, *args) -> None:
def remove(self) -> None:
"""Permanently remove this folder and all contained GUI elements from the
visualizer."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Remove children, then self.
self._impl.gui_api._websock_interface.queue_message(
GuiRemoveMessage(self._impl.id)
)
Expand Down
2 changes: 0 additions & 2 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,6 @@ class MeshProps:
"""A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). Synchronized automatically when assigned."""
color: Optional[Tuple[int, int, int]]
"""Color of the mesh as RGB integers. Synchronized automatically when assigned."""
vertex_colors: Optional[npt.NDArray[np.uint8]]
"""Optional array of vertex colors. Synchronized automatically when assigned."""
wireframe: bool
"""Boolean indicating if the mesh should be rendered as a wireframe. Synchronized automatically when assigned."""
opacity: Optional[float]
Expand Down
12 changes: 10 additions & 2 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ def add_mesh_skinned(
stacklevel=2,
)

assert len(bone_wxyzs) == len(bone_positions)
num_bones = len(bone_wxyzs)
assert skin_weights.shape == (vertices.shape[0], num_bones)

Expand All @@ -1059,7 +1060,6 @@ def add_mesh_skinned(
vertices=vertices.astype(np.float32),
faces=faces.astype(np.uint32),
color=_encode_rgb(color),
vertex_colors=None,
wireframe=wireframe,
opacity=opacity,
flat_shading=flat_shading,
Expand Down Expand Up @@ -1153,7 +1153,6 @@ def add_mesh_simple(
vertices=vertices.astype(np.float32),
faces=faces.astype(np.uint32),
color=_encode_rgb(color),
vertex_colors=None,
wireframe=wireframe,
opacity=opacity,
flat_shading=flat_shading,
Expand Down Expand Up @@ -1757,3 +1756,12 @@ def add_3d_gui_container(
self, message, name, wxyz, position, visible=visible
)
return Gui3dContainerHandle(node_handle._impl, gui_api, container_id)

def remove_by_name(self, name: str) -> None:
"""Helper to call `.remove()` on the scene node handles of the `name`
element or any of its children."""
handle_from_node_name = self._handle_from_node_name.copy()
name = name.rstrip("/") # '/parent/' => '/parent'
for node_name, handle in handle_from_node_name.items():
if node_name == name or node_name.startswith(name + "/"):
handle.remove()
57 changes: 48 additions & 9 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import dataclasses
import warnings
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand All @@ -17,7 +18,7 @@

import numpy as np
import numpy.typing as onpt
from typing_extensions import get_type_hints
from typing_extensions import Self, get_type_hints

from . import _messages
from .infra._infra import WebsockClientConnection, WebsockServer
Expand Down Expand Up @@ -123,10 +124,10 @@ class _SceneNodeHandleState:
default_factory=lambda: np.array([0.0, 0.0, 0.0])
)
visible: bool = True
# TODO: we should remove SceneNodeHandle as an argument here.
click_cb: list[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] | None = (
None
)
click_cb: list[
Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None]
] = dataclasses.field(default_factory=list)
removed: bool = False


class _SceneNodeMessage(Protocol):
Expand Down Expand Up @@ -223,6 +224,12 @@ def visible(self, visible: bool) -> None:

def remove(self) -> None:
"""Remove the node from the scene."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(f"Attempted to remove already removed node: {self.name}")
return

self._impl.removed = True
self._impl.api._websock_interface.queue_message(
_messages.RemoveSceneNodeMessage(self._impl.name)
)
Expand Down Expand Up @@ -253,18 +260,35 @@ class SceneNodePointerEvent(Generic[TSceneNodeHandle]):

class _ClickableSceneNodeHandle(SceneNodeHandle):
def on_click(
self: TSceneNodeHandle,
func: Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None],
) -> Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None]:
self: Self,
func: Callable[[SceneNodePointerEvent[Self]], None],
) -> Callable[[SceneNodePointerEvent[Self]], None]:
"""Attach a callback for when a scene node is clicked."""
self._impl.api._websock_interface.queue_message(
_messages.SetSceneNodeClickableMessage(self._impl.name, True)
)
if self._impl.click_cb is None:
self._impl.click_cb = []
self._impl.click_cb.append(func) # type: ignore
self._impl.click_cb.append(
cast(
Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None], func
)
)
return func

def remove_click_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove click callbacks from scene node.

Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl.click_cb.clear()
else:
self._impl.click_cb = [cb for cb in self._impl.click_cb if cb != callback]


class CameraFrustumHandle(
_ClickableSceneNodeHandle,
Expand Down Expand Up @@ -510,6 +534,21 @@ def on_update(
self._impl_aux.update_cb.append(func)
return func

def remove_update_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove update callbacks from the transform controls.

Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl_aux.update_cb.clear()
else:
self._impl_aux.update_cb = [
cb for cb in self._impl_aux.update_cb if cb != callback
]


class Gui3dContainerHandle(
SceneNodeHandle,
Expand Down
8 changes: 4 additions & 4 deletions src/viser/client/src/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ export default function ControlPanel(props: {
controlWidthString == "small"
? "16em"
: controlWidthString == "medium"
? "20em"
: controlWidthString == "large"
? "24em"
: null
? "20em"
: controlWidthString == "large"
? "24em"
: null
)!;

const generatedServerToggleButton = (
Expand Down
24 changes: 14 additions & 10 deletions src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,20 @@ function useObjectFactory(message: SceneNodeMessage | undefined): {
message.props.plane == "xz"
? new THREE.Euler(0.0, 0.0, 0.0)
: message.props.plane == "xy"
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.props.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.props.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.props.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.props.plane == "zy"
? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0)
: undefined
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.props.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.props.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.props.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.props.plane == "zy"
? new THREE.Euler(
-Math.PI / 2.0,
0.0,
-Math.PI / 2.0,
)
: undefined
}
/>
</group>
Expand Down
40 changes: 10 additions & 30 deletions src/viser/client/src/ThreeAssets.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -408,19 +408,6 @@ export const InstancedAxes = React.forwardRef<
});

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
return new THREE.Float32BufferAttribute(
new Float32Array(new Uint8Array(colors)).map((value) => {
value = value / 255.0;
if (value <= 0.04045) {
return value / 12.92;
} else {
return Math.pow((value + 0.055) / 1.055, 2.4);
}
}),
3,
);
}
export const ViserMesh = React.forwardRef<
THREE.Mesh | THREE.SkinnedMesh,
MeshMessage | SkinnedMeshMessage
Expand Down Expand Up @@ -448,7 +435,6 @@ export const ViserMesh = React.forwardRef<
const standardArgs = {
color:
message.props.color === null ? undefined : rgbToInt(message.props.color),
vertexColors: message.props.vertex_colors !== null,
wireframe: message.props.wireframe,
transparent: message.props.opacity !== null,
opacity: message.props.opacity ?? 1.0,
Expand All @@ -474,16 +460,16 @@ export const ViserMesh = React.forwardRef<
message.props.material == "standard" || message.props.wireframe
? new THREE.MeshStandardMaterial(standardArgs)
: message.props.material == "toon3"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.props.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.props.material);
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.props.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.props.material);
const geometry = new THREE.BufferGeometry();
geometry.setAttribute(
"position",
Expand All @@ -498,12 +484,6 @@ export const ViserMesh = React.forwardRef<
3,
),
);
if (message.props.vertex_colors !== null) {
geometry.setAttribute(
"color",
threeColorBufferFromUint8Buffer(message.props.vertex_colors),
);
}

geometry.setIndex(
new THREE.Uint32BufferAttribute(
Expand Down
Loading
Loading