From e692d16ffedc6bf94b2bdf8332eedd615dd7d35c Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sat, 6 Jul 2024 16:33:42 -0700 Subject: [PATCH] diff_traj: added linear encodings --- torchdrive/tasks/diff_traj.py | 121 ++++++++++++++++++++++++----- torchdrive/tasks/test_diff_traj.py | 76 +++++++++++++++--- 2 files changed, 168 insertions(+), 29 deletions(-) diff --git a/torchdrive/tasks/diff_traj.py b/torchdrive/tasks/diff_traj.py index 6b17ada..09d5bcd 100644 --- a/torchdrive/tasks/diff_traj.py +++ b/torchdrive/tasks/diff_traj.py @@ -3,18 +3,18 @@ import torch import torch.nn.functional as F + +from diffusers import DDPMScheduler from torch import nn from torch.optim.swa_utils import AveragedModel from torch.utils.tensorboard import SummaryWriter -from diffusers import DDPMScheduler -from torchtune.modules import RotaryPositionalEmbeddings - from torchdrive.amp import autocast from torchdrive.autograd import autograd_context, register_log_grad_norm from torchdrive.data import Batch from torchdrive.losses import losses_backward from torchdrive.tasks.van import Van +from torchtune.modules import RotaryPositionalEmbeddings from torchworld.models.vit import MaskViT from torchworld.transforms.img import normalize_img, normalize_mask, render_color from torchworld.transforms.mask import random_block_mask, true_mask @@ -37,7 +37,9 @@ def __init__( self.cam_shape = cam_shape self.dim = dim self.num_heads = num_heads - self.positional_embedding = RotaryPositionalEmbeddings(dim//num_heads, max_seq_len=max_seq_len) + self.positional_embedding = RotaryPositionalEmbeddings( + dim // num_heads, max_seq_len=max_seq_len + ) layers: OrderedDict[str, nn.Module] = OrderedDict() for i in range(num_layers): @@ -74,6 +76,7 @@ def forward(self, input: torch.Tensor, condition: torch.Tensor): return x + class XYEmbedding(nn.Module): def __init__(self, shape: Tuple[int, int], scale: float, dim: int): """ @@ -89,9 +92,7 @@ def __init__(self, shape: Tuple[int, int], scale: float, dim: int): 
self.scale = scale self.shape = shape - self.embedding = nn.Parameter( - torch.empty(*shape, dim).normal_(std=0.02) - ) + self.embedding = nn.Parameter(torch.empty(*shape, dim).normal_(std=0.02)) def forward(self, pos: torch.Tensor): """ @@ -102,13 +103,13 @@ def forward(self, pos: torch.Tensor): the embedding of the position (..., dim) """ - dx = (self.shape[0]-1) // 2 - dy = (self.shape[1]-1) // 2 + dx = (self.shape[0] - 1) // 2 + dy = (self.shape[1] - 1) // 2 x = (pos[..., 0] * dx / self.scale + dx).long() y = (pos[..., 1] * dy / self.scale + dy).long() - x = x.clamp(min=0, max=self.shape[0]-1) - y = y.clamp(min=0, max=self.shape[1]-1) + x = x.clamp(min=0, max=self.shape[0] - 1) + y = y.clamp(min=0, max=self.shape[1] - 1) return self.embedding[x, y] @@ -118,7 +119,7 @@ def decode(self, input: torch.Tensor) -> torch.Tensor: Args: input: input embedding to decode (bs, seq_len, dim) - + Returns: the position (bs, seq_len, 2) """ @@ -136,8 +137,8 @@ def decode(self, input: torch.Tensor) -> torch.Tensor: x = torch.div(classes, self.shape[1], rounding_mode="floor") y = torch.remainder(classes, self.shape[1]) - dx = (self.shape[0]-1) // 2 - dy = (self.shape[1]-1) // 2 + dx = (self.shape[0] - 1) // 2 + dy = (self.shape[1] - 1) // 2 x = (x.float() - dx) * self.scale / dx y = (y.float() - dy) * self.scale / dy @@ -145,14 +146,94 @@ def decode(self, input: torch.Tensor) -> torch.Tensor: # 2x (bs, seq_len) -> (bs, seq_len, 2) return torch.stack([x, y], dim=-1) - +class XEmbedding(nn.Module): + def __init__(self, shape: int, scale: float, dim: int): + """ + Initialize the XEmbedding, which is a linear embedding. 
+ + Arguments: + shape: the size of the embedding grid [x], the center is 0.0 + scale: the max coordinate value + dim: dimension of the embedding + """ + super().__init__() + + self.scale = scale + self.shape = shape + + self.embedding = nn.Parameter(torch.empty(shape, dim).normal_(std=0.02)) + + def forward(self, pos: torch.Tensor): + """ + Args: + pos: the list of positions (...) + + Returns: + the embedding of the position (..., dim) + """ + + dx = (self.shape - 1) // 2 + x = (pos * dx / self.scale + dx).long() + + x = x.clamp(min=0, max=self.shape - 1) + + return self.embedding[x] + + def decode(self, input: torch.Tensor) -> torch.Tensor: + """ + Convert the embedding back to the position using an un-normalized dot-product similarity (argmax over the embedding table). + + Args: + input: input embedding to decode (bs, seq_len, dim) + + Returns: + the position (bs, seq_len) + """ + + # (bs, seq_len, dim) @ (x, dim) -> (bs, seq_len, x) + similarity = torch.einsum("bsd,xd->bsx", input, self.embedding) + + x = similarity.argmax(dim=-1) + + dx = (self.shape - 1) // 2 + x = (x.float() - dx) * self.scale / dx + return x + + +class XYLinearEmbedding(nn.Module): + def __init__(self, shape: Tuple[int, int], scale: float, dim: int): + """ + Initialize the XYLinearEmbedding which is a 2d embedding comprised of two linear XEmbeddings. 
+ + Arguments: + shape: the size of the embedding grid [x, y], the center is 0.0 + scale: the max coordinate value + dim: dimension of the embedding (split in 2 for the two child embeddings) + """ + super().__init__() + + self.dim = dim // 2 + + self.x = XEmbedding(shape[0], scale, dim // 2) + self.y = XEmbedding(shape[1], scale, dim // 2) + + def forward(self, pos: torch.Tensor): + x = self.x(pos[..., 0]) + y = self.y(pos[..., 1]) + return torch.cat([x, y], dim=-1) + + def decode(self, input: torch.Tensor) -> torch.Tensor: + x = self.x.decode(input[..., : self.dim]) + y = self.y.decode(input[..., self.dim :]) + return torch.stack([x, y], dim=-1) class DiffTraj(nn.Module, Van): """ A diffusion model for trajectory detection. """ + def __init__( self, cameras: List[str], @@ -181,7 +262,13 @@ def __init__( } ) - self.trajectory_ + # embedding + # + self.xy_embedding = XYEmbedding( + shape=(128, 128), + scale=100, # 100 meters + dim=dim, + ) self.decoder = Decoder( max_seq_len=20, @@ -280,8 +367,6 @@ def forward( input_tokens = torch.cat(all_feats, dim=1) - - noise_scheduler = DDPMScheduler(num_train_timesteps=1000) noise = torch.randn(sample_image.shape) timesteps = torch.LongTensor([50]) diff --git a/torchdrive/tasks/test_diff_traj.py b/torchdrive/tasks/test_diff_traj.py index 28e87dc..bde9779 100644 --- a/torchdrive/tasks/test_diff_traj.py +++ b/torchdrive/tasks/test_diff_traj.py @@ -1,12 +1,14 @@ -from torchdrive.tasks.diff_traj import XYEmbedding - import unittest import torch +from torchdrive.tasks.diff_traj import XEmbedding, XYEmbedding, XYLinearEmbedding + class TestDiffTraj(unittest.TestCase): - def test_diff_traj(self): - dim = 20 + def test_xy_embedding(self): + torch.manual_seed(0) + + dim = 32 traj = XYEmbedding( shape=(16, 24), @@ -14,13 +16,15 @@ def test_diff_traj(self): scale=1.0, ) - input = torch.tensor([ - (0.0, 0.0), - (1.0, 0.0), - (0.0, 1.0), - (-1.0, 0.0), - (0.0, -1.0), - ]).unsqueeze(0) + input = torch.tensor( + [ + (0.0, 0.0), + (1.0, 0.0), 
+ (0.0, 1.0), + (-1.0, 0.0), + (0.0, -1.0), + ] + ).unsqueeze(0) output = traj(input) self.assertEqual(output.shape, (1, 5, dim)) @@ -28,3 +32,53 @@ def test_diff_traj(self): positions = traj.decode(output) self.assertEqual(positions.shape, (1, 5, 2)) torch.testing.assert_close(positions, input) + + def test_xy_linear_embedding(self): + torch.manual_seed(0) + + dim = 32 + + traj = XYLinearEmbedding( + shape=(16, 24), + dim=dim, + scale=1.0, + ) + + input = torch.tensor( + [ + (0.0, 0.0), + (1.0, 0.0), + (0.0, 1.0), + (-1.0, 0.0), + (0.0, -1.0), + ] + ).unsqueeze(0) + + output = traj(input) + self.assertEqual(output.shape, (1, 5, dim)) + + positions = traj.decode(output) + self.assertEqual(positions.shape, (1, 5, 2)) + torch.testing.assert_close(positions, input) + + def test_x_embedding(self): + torch.manual_seed(0) + + dim = 20 + + traj = XEmbedding(shape=16, dim=dim, scale=1.0) + + input = torch.tensor( + [ + 0.0, + -1.0, + 1.0, + ] + ).unsqueeze(0) + + output = traj(input) + self.assertEqual(output.shape, (1, 3, dim)) + + positions = traj.decode(output) + self.assertEqual(positions.shape, (1, 3)) + torch.testing.assert_close(positions, input)