Skip to content

Commit

Permalink
torchworld/points: made more general
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Nov 2, 2023
1 parent b048dd5 commit b51c6ae
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 28 deletions.
14 changes: 0 additions & 14 deletions torchworld/structures/lidar.py

This file was deleted.

16 changes: 16 additions & 0 deletions torchworld/structures/points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass

import torch


@dataclass
class Points:
"""Points represents a set of points in world coordinates.
Attributes
----------
data: [bs, 3+, num_points]
World coordinates with metadata [x, y, z, ... (intensity?)]
"""

data: torch.Tensor
8 changes: 4 additions & 4 deletions torchworld/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from torchworld.structures.cameras import PerspectiveCameras
from torchworld.structures.grid import Grid3d, GridImage
from torchworld.structures.lidar import Lidar
from torchworld.structures.points import Points
from torchworld.transforms.transform3d import Transform3d


Expand Down Expand Up @@ -43,7 +43,7 @@ def test_path(self) -> None:
out = vis.path(positions)
self.assertIsInstance(out, pythreejs.Group)

def test_lidar(self) -> None:
lidar = Lidar(data=torch.rand(1, 4, 10))
out = vis.lidar(lidar)
def test_points(self) -> None:
points = Points(data=torch.rand(1, 4, 10))
out = vis.points(points)
self.assertIsInstance(out, pythreejs.Points)
21 changes: 11 additions & 10 deletions torchworld/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from torchworld.structures.cameras import CamerasBase
from torchworld.structures.grid import Grid3d, GridImage
from torchworld.structures.lidar import Lidar
from torchworld.structures.points import Points
from torchworld.transforms.img import normalize_img_cuda
from torchworld.transforms.transform3d import RotateAxisAngle, Transform3d

Expand Down Expand Up @@ -368,18 +368,19 @@ def path(positions: Transform3d) -> pythreejs.Object3D:
return group


def lidar(data: Lidar) -> pythreejs.Object3D:
"""lidar returns a object that renders the lidar data.
def points(points: Points) -> pythreejs.Object3D:
"""points returns a object that renders the points data. This only uses the
first 3 channels of the data [x, y, z].
Arguments
---------
data: The Lidar data
points: Points
"""
if data.data.size(0) != 1:
raise TypeError("lidar must have batch size 1")
if points.data.size(0) != 1:
raise TypeError("points must have batch size 1")

lidar_geo = pythreejs.BufferGeometry(
attributes={"position": pythreejs.BufferAttribute(data.data[0, :3].T.numpy())}
points_geo = pythreejs.BufferGeometry(
attributes={"position": pythreejs.BufferAttribute(points.data[0, :3].T.numpy())}
)
lidar_mat = pythreejs.PointsMaterial(color="white", size=1, sizeAttenuation=False)
return pythreejs.Points(lidar_geo, lidar_mat)
points_mat = pythreejs.PointsMaterial(color="white", size=1, sizeAttenuation=False)
return pythreejs.Points(points_geo, points_mat)

0 comments on commit b51c6ae

Please sign in to comment.