Skip to content

Commit

Permalink
diff_traj: random_traj fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Aug 4, 2024
1 parent f449160 commit 0fd7b40
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
29 changes: 15 additions & 14 deletions notebooks/compute_dream_pos.ipynb

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,14 +722,24 @@ def forward(
def random_traj(
BS: int, seq_len: int, device: object, vel: torch.Tensor
) -> torch.Tensor:
"""Generates a random trajectory at the specified velocity."""
"""Generates a random trajectory at the specified velocity.
Arguments:
BS: batch size
vel: [BS, 1]
Returns:
The random trajectory [BS, seq_len, 2]
"""

# scale from 0.5 to 1.5
speed = (torch.rand(BS, device=device) + 0.5) * vel
speed = (torch.rand(BS, device=device) + 0.5) * vel.squeeze(1)

angle = torch.rand(BS, device=device) * math.pi
x = torch.sin(angle) * torch.arange(seq_len, device=device) / 2 * speed
y = torch.cos(angle) * torch.arange(seq_len, device=device) / 2 * speed
angle = torch.rand(BS, 1, device=device) * math.pi
x = torch.arange(seq_len, device=device) / 2 * speed.unsqueeze(1)
x *= torch.sin(angle)
y = torch.arange(seq_len, device=device) / 2 * speed.unsqueeze(1)
y *= torch.cos(angle)

traj = torch.stack([x, y], dim=-1)
return traj
Expand Down
9 changes: 8 additions & 1 deletion torchdrive/tasks/test_diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchdrive.tasks.diff_traj import (
compute_dream_pos,
DiffTraj,
random_traj,
square_mask,
XEmbedding,
XYEmbedding,
Expand Down Expand Up @@ -218,7 +219,6 @@ def test_square_mask(self):
torch.testing.assert_close(output[:2], target)

def test_compute_dream_pos(self):

positions = torch.rand(2, 18, 2)
mask = torch.ones(2, 18)
pred_traj = torch.rand(2, 18, 2)
Expand All @@ -230,3 +230,10 @@ def test_compute_dream_pos(self):
self.assertEqual(dream_mask.shape, (2, 16))
self.assertEqual(dream_positions.shape, (2, 16, 2))
self.assertEqual(dream_pred.shape, (2, 16, 2))

def test_random_traj(self):
BS = 10
vel = torch.ones(BS, 1)
seq_len = 18
traj = random_traj(BS=BS, seq_len=seq_len, device="cpu", vel=vel)
self.assertEqual(traj.shape, (BS, seq_len, 2))

0 comments on commit 0fd7b40

Please sign in to comment.