diff_traj: added linear encodings
d4l3k committed Jul 6, 2024
1 parent d3aff48 commit e692d16
Showing 2 changed files with 168 additions and 29 deletions.
121 changes: 103 additions & 18 deletions torchdrive/tasks/diff_traj.py
@@ -3,18 +3,18 @@

import torch
import torch.nn.functional as F

-from diffusers import DDPMScheduler
from torch import nn
from torch.optim.swa_utils import AveragedModel
from torch.utils.tensorboard import SummaryWriter

+from diffusers import DDPMScheduler
+from torchtune.modules import RotaryPositionalEmbeddings

from torchdrive.amp import autocast
from torchdrive.autograd import autograd_context, register_log_grad_norm
from torchdrive.data import Batch
from torchdrive.losses import losses_backward
from torchdrive.tasks.van import Van
-from torchtune.modules import RotaryPositionalEmbeddings
from torchworld.models.vit import MaskViT
from torchworld.transforms.img import normalize_img, normalize_mask, render_color
from torchworld.transforms.mask import random_block_mask, true_mask
@@ -37,7 +37,9 @@ def __init__(
self.cam_shape = cam_shape
self.dim = dim
self.num_heads = num_heads
-        self.positional_embedding = RotaryPositionalEmbeddings(dim//num_heads, max_seq_len=max_seq_len)
+        self.positional_embedding = RotaryPositionalEmbeddings(
+            dim // num_heads, max_seq_len=max_seq_len
+        )

layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
@@ -74,6 +76,7 @@ def forward(self, input: torch.Tensor, condition: torch.Tensor):

return x


class XYEmbedding(nn.Module):
def __init__(self, shape: Tuple[int, int], scale: float, dim: int):
"""
@@ -89,9 +92,7 @@ def __init__(self, shape: Tuple[int, int], scale: float, dim: int):
self.scale = scale
self.shape = shape

-        self.embedding = nn.Parameter(
-            torch.empty(*shape, dim).normal_(std=0.02)
-        )
+        self.embedding = nn.Parameter(torch.empty(*shape, dim).normal_(std=0.02))

def forward(self, pos: torch.Tensor):
"""
@@ -102,13 +103,13 @@ def __init__(self, shape: Tuple[int, int], scale: float, dim: int):
the embedding of the position (..., dim)
"""

-        dx = (self.shape[0]-1) // 2
-        dy = (self.shape[1]-1) // 2
+        dx = (self.shape[0] - 1) // 2
+        dy = (self.shape[1] - 1) // 2
x = (pos[..., 0] * dx / self.scale + dx).long()
y = (pos[..., 1] * dy / self.scale + dy).long()

-        x = x.clamp(min=0, max=self.shape[0]-1)
-        y = y.clamp(min=0, max=self.shape[1]-1)
+        x = x.clamp(min=0, max=self.shape[0] - 1)
+        y = y.clamp(min=0, max=self.shape[1] - 1)

return self.embedding[x, y]

@@ -118,7 +119,7 @@ def decode(self, input: torch.Tensor) -> torch.Tensor:
Args:
input: input embedding to decode (bs, seq_len, dim)
Returns:
the position (bs, seq_len, 2)
"""
@@ -136,23 +137,103 @@ def decode(self, input: torch.Tensor) -> torch.Tensor:
x = torch.div(classes, self.shape[1], rounding_mode="floor")
y = torch.remainder(classes, self.shape[1])

-        dx = (self.shape[0]-1) // 2
-        dy = (self.shape[1]-1) // 2
+        dx = (self.shape[0] - 1) // 2
+        dy = (self.shape[1] - 1) // 2

x = (x.float() - dx) * self.scale / dx
y = (y.float() - dy) * self.scale / dy

# 2x (bs, seq_len) -> (bs, seq_len, 2)
return torch.stack([x, y], dim=-1)

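For orientation, a minimal round-trip sketch of the XYEmbedding above (the shape and scale mirror the DiffTraj config later in this diff; dim and the batch shapes are illustrative):

import torch

from torchdrive.tasks.diff_traj import XYEmbedding

emb = XYEmbedding(shape=(128, 128), scale=100.0, dim=64)

# (bs=1, seq_len=2, 2) positions in meters, within [-scale, scale] per axis
pos = torch.tensor([[[0.0, 0.0], [12.3, -45.6]]])

tokens = emb(pos)  # (1, 2, 64): the learned vector of the nearest grid cell
decoded = emb.decode(tokens)  # (1, 2, 2): positions snapped back onto the grid

# With dx = dy = (128 - 1) // 2 = 63, one grid step is scale / dx, about
# 1.59 m, so decoded matches pos only up to that quantization. The argmax in
# decode() recovers the original cell with high probability under the random
# init; the tests below fix a seed for this reason.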


class XEmbedding(nn.Module):
def __init__(self, shape: int, scale: float, dim: int):
"""
Initialize the XEmbedding, which is a linear embedding.
Arguments:
shape: the size of the embedding grid [x], the center is 0.0
scale: the max coordinate value
dim: dimension of the embedding
"""
super().__init__()

self.scale = scale
self.shape = shape

self.embedding = nn.Parameter(torch.empty(shape, dim).normal_(std=0.02))

def forward(self, pos: torch.Tensor):
"""
Args:
pos: the list of positions (...)
Returns:
the embedding of the position (..., dim)
"""

dx = (self.shape - 1) // 2
x = (pos * dx / self.scale + dx).long()

x = x.clamp(min=0, max=self.shape - 1)

return self.embedding[x]

def decode(self, input: torch.Tensor) -> torch.Tensor:
"""
        Convert the embedding back to the position using a dot-product similarity lookup over the grid embeddings.
Args:
input: input embedding to decode (bs, seq_len, dim)
Returns:
the position (bs, seq_len)
"""

# (bs, seq_len, dim) @ (x, dim) -> (bs, seq_len, x)
similarity = torch.einsum("bsd,xd->bsx", input, self.embedding)

x = similarity.argmax(dim=-1)

dx = (self.shape - 1) // 2
x = (x.float() - dx) * self.scale / dx
return x


class XYLinearEmbedding(nn.Module):
def __init__(self, shape: Tuple[int, int], scale: float, dim: int):
"""
        Initialize the XYLinearEmbedding, a 2D embedding composed of two linear XEmbeddings.
Arguments:
shape: the size of the embedding grid [x, y], the center is 0.0
scale: the max coordinate value
dim: dimension of the embedding (split in 2 for the two child embeddings)
"""
super().__init__()

self.dim = dim // 2

self.x = XEmbedding(shape[0], scale, dim // 2)
self.y = XEmbedding(shape[1], scale, dim // 2)

def forward(self, pos: torch.Tensor):
x = self.x(pos[..., 0])
y = self.y(pos[..., 1])
return torch.cat([x, y], dim=-1)

def decode(self, input: torch.Tensor) -> torch.Tensor:
x = self.x.decode(input[..., : self.dim])
y = self.y.decode(input[..., self.dim :])
return torch.stack([x, y], dim=-1)

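The point of the linear variant (the "linear encodings" of the commit title) is parameter count: XYEmbedding learns one dim-sized vector per (x, y) cell, while XYLinearEmbedding learns per-axis vectors of size dim / 2 and concatenates them. A rough count at the grid size used below, with dim = 512 chosen purely for illustration:

dim = 512  # illustrative; the real model dim is configured elsewhere

xy_joint = 128 * 128 * dim  # XYEmbedding: one vector per grid cell
xy_linear = (128 + 128) * (dim // 2)  # XYLinearEmbedding: one vector per row/column

print(f"{xy_joint:,}")   # 8,388,608
print(f"{xy_linear:,}")  # 65,536

The trade-off is that x and y are encoded and decoded independently, so a single token cannot represent joint structure between the two axes.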

class DiffTraj(nn.Module, Van):
"""
A diffusion model for trajectory detection.
"""

def __init__(
self,
cameras: List[str],
@@ -181,7 +262,13 @@ def __init__(
}
)

-        self.trajectory_
+        # embedding
+        #
+        self.xy_embedding = XYEmbedding(
+            shape=(128, 128),
+            scale=100,  # 100 meters
+            dim=dim,
+        )

self.decoder = Decoder(
max_seq_len=20,
@@ -280,8 +367,6 @@ def forward(

input_tokens = torch.cat(all_feats, dim=1)

-
-
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
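The hunk above is truncated where the forward pass starts driving the diffusers noise scheduler. As a rough sketch of how these three lines are typically combined under the standard diffusers API (the token shape is an assumption, not code from this commit):

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

sample = torch.randn(1, 20, 256)  # stand-in for embedded trajectory tokens
noise = torch.randn(sample.shape)
timesteps = torch.LongTensor([50])

# add_noise() mixes signal and noise at the given timestep; a diffusion
# model is then trained to predict the noise from the noisy sample.
noisy = noise_scheduler.add_noise(sample, noise, timesteps)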
76 changes: 65 additions & 11 deletions torchdrive/tasks/test_diff_traj.py
@@ -1,30 +1,84 @@
-from torchdrive.tasks.diff_traj import XYEmbedding

import unittest

import torch
+from torchdrive.tasks.diff_traj import XEmbedding, XYEmbedding, XYLinearEmbedding


class TestDiffTraj(unittest.TestCase):
-    def test_diff_traj(self):
-        dim = 20
+    def test_xy_embedding(self):
+        torch.manual_seed(0)
+
+        dim = 32

traj = XYEmbedding(
shape=(16, 24),
dim=dim,
scale=1.0,
)

-        input = torch.tensor([
-            (0.0, 0.0),
-            (1.0, 0.0),
-            (0.0, 1.0),
-            (-1.0, 0.0),
-            (0.0, -1.0),
-        ]).unsqueeze(0)
+        input = torch.tensor(
+            [
+                (0.0, 0.0),
+                (1.0, 0.0),
+                (0.0, 1.0),
+                (-1.0, 0.0),
+                (0.0, -1.0),
+            ]
+        ).unsqueeze(0)

output = traj(input)
self.assertEqual(output.shape, (1, 5, dim))

positions = traj.decode(output)
self.assertEqual(positions.shape, (1, 5, 2))
torch.testing.assert_close(positions, input)

def test_xy_linear_embedding(self):
torch.manual_seed(0)

dim = 32

traj = XYLinearEmbedding(
shape=(16, 24),
dim=dim,
scale=1.0,
)

input = torch.tensor(
[
(0.0, 0.0),
(1.0, 0.0),
(0.0, 1.0),
(-1.0, 0.0),
(0.0, -1.0),
]
).unsqueeze(0)

output = traj(input)
self.assertEqual(output.shape, (1, 5, dim))

positions = traj.decode(output)
self.assertEqual(positions.shape, (1, 5, 2))
torch.testing.assert_close(positions, input)

def test_x_embedding(self):
torch.manual_seed(0)

dim = 20

traj = XEmbedding(shape=16, dim=dim, scale=1.0)

input = torch.tensor(
[
0.0,
-1.0,
1.0,
]
).unsqueeze(0)

output = traj(input)
self.assertEqual(output.shape, (1, 3, dim))

positions = traj.decode(output)
self.assertEqual(positions.shape, (1, 3))
torch.testing.assert_close(positions, input)

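A note on why torch.testing.assert_close can demand exact recovery in these tests: every probe value lands on an integer grid index, so the encode/decode round trip has no truncation error. For the size-16 axis, following the index math shown in the diff:

# With scale = 1.0 and shape 16: dx = (16 - 1) // 2 = 7, and the probe
# values {-1.0, 0.0, 1.0} map to indices {0, 7, 14} exactly, so decoding
# via (idx - dx) * scale / dx returns the original value with no error.
dx = (16 - 1) // 2
for v in (-1.0, 0.0, 1.0):
    idx = int(v * dx / 1.0 + dx)
    assert (idx - dx) * 1.0 / dx == v

The same holds for the size-24 axis (dy = 11, probes map to {0, 11, 22}), which is why the 2D tests also pass exactly.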