Skip to content

Commit

Permalink
diff_traj: use ae mlp
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 7, 2024
1 parent 992c5be commit c120531
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 23 deletions.
10 changes: 10 additions & 0 deletions torchdrive/models/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def encode_one_hot(self, xy: torch.Tensor) -> torch.Tensor:
y = F.one_hot(y, self.num_buckets).float()
return torch.cat((x, y), dim=2).permute(0, 2, 1)

def forward(self, xy: torch.Tensor) -> torch.Tensor:
"""
Encodes the xy coordinates into one hot encoding.
Returns
-------
xy: [bs, 2, seq_len]
"""
return self.encode_one_hot(xy)

def decode(self, xy: torch.Tensor) -> torch.Tensor:
"""
Decodes from logit/probabilities one hot encoding.
Expand Down
119 changes: 97 additions & 22 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler
from diffusers import EulerDiscreteScheduler
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchdrive.amp import autocast
from torchdrive.autograd import autograd_context, register_log_grad_norm
from torchdrive.data import Batch
from torchdrive.losses import losses_backward
from torchdrive.models.mlp import MLP
from torchdrive.models.path import XYEncoder
from torchdrive.tasks.van import Van

from torchdrive.transforms.batch import NormalizeCarPosition
from torchtune.modules import RotaryPositionalEmbeddings
from torchworld.models.vit import MaskViT
Expand Down Expand Up @@ -267,6 +268,39 @@ def ae_loss(self, input: torch.Tensor) -> torch.Tensor:
return x + y


class XYMLPEncoder(nn.Module):
def __init__(self, dim: int, max_dist: float, dropout: float = 0.1) -> None:
super().__init__()

self.embedding = XYEncoder(num_buckets=dim // 2, max_dist=max_dist)
self.encoder = MLP(dim, dim, dim, num_layers=3, dropout=dropout)
self.decoder = MLP(dim, dim, dim, num_layers=3, dropout=dropout)

def forward(self, xy: torch.Tensor) -> torch.Tensor:
"""
Args:
xy: the list of positions (..., 2)
Returns:
the embedding of the position (..., dim)
"""
xy = xy.permute(0, 2, 1)
one_hot = self.embedding.encode_one_hot(xy)
return self.encoder(one_hot).permute(0, 2, 1)

def decode(self, input: torch.Tensor) -> torch.Tensor:
emb = self.decoder(input.permute(0, 2, 1))
xy = self.embedding.decode(emb).permute(0, 2, 1)
return xy

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
predicted = predicted.permute(0, 2, 1)
target = target.permute(0, 2, 1)
emb = self.decoder(predicted)
print(predicted.shape, target.shape, emb.shape)
return self.embedding.loss(emb, target)


class DiffTraj(nn.Module, Van):
"""
A diffusion model for trajectory detection.
Expand All @@ -283,6 +317,8 @@ def __init__(
num_heads: int = 16,
num_encode_frames: int = 1,
num_frames: int = 1,
num_inference_timesteps: int = 20,
num_train_timesteps: int = 1000,
):
super().__init__()

Expand All @@ -291,6 +327,9 @@ def __init__(
self.num_encode_frames = num_encode_frames
self.cam_shape = cam_shape
self.feat_shape = (cam_shape[0] // 16, cam_shape[1] // 16)
self.num_train_timesteps = num_train_timesteps
self.num_inference_timesteps = num_inference_timesteps

self.encoders = nn.ModuleDict(
{
cam: MaskViT(
Expand All @@ -303,12 +342,7 @@ def __init__(
)

# embedding
# 2*100m/512 == 0.39 meters
self.xy_embedding = XYLinearEmbedding(
shape=(512, 512),
scale=100, # 100 meters
dim=dim,
)
self.xy_embedding = XYMLPEncoder(dim=dim, max_dist=128)

self.denoiser = Denoiser(
max_seq_len=256,
Expand All @@ -325,7 +359,9 @@ def __init__(
nn.Linear(dim, dim),
)

self.noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
self.noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)
self.eval_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)
self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)

self.batch_transform = NormalizeCarPosition(start_frame=0)

Expand Down Expand Up @@ -399,14 +435,26 @@ def forward(
with autocast():
# checkpoint encoders to save memory
encoder = self.encoders[cam]
cam_feats = torch.utils.checkpoint.checkpoint(
unmasked, cam_feats = torch.utils.checkpoint.checkpoint(
encoder,
feats.flatten(0, 1),
mask,
use_reentrant=False,
)
assert cam_feats.requires_grad, f"missing grad for cam {cam}"

if writer is not None and log_img:
writer.add_image(
f"{cam}/color",
normalize_img(feats[0, 0]),
global_step=global_step,
)
writer.add_image(
f"{cam}/pca",
render_pca(unmasked[0].permute(1, 2, 0)),
global_step=global_step,
)

if writer is not None and log_text:
register_log_grad_norm(
t=cam_feats,
Expand All @@ -419,13 +467,6 @@ def forward(
# (n, seq_len, hidden_dim) -> (bs, num_encode_frames, seq_len, hidden_dim)
cam_feats = cam_feats.unflatten(0, feats.shape[:2])

if writer is not None and log_img:
writer.add_image(
f"{cam}/pca",
render_color(cam_feats[0, 0]),
global_step=global_step,
)

# flatten time
# (bs, num_encode_frames, seq_len, hidden_dim) -> (bs, num_encode_frames * seq_len, hidden_dim)
cam_feats = cam_feats.flatten(1, 2)
Expand All @@ -447,7 +488,7 @@ def forward(
positions /= positions[..., -1:] + 1e-8 # perspective warp
positions = positions[..., :2]

losses["xy_embedding/ae"] = self.xy_embedding(positions)
# losses["xy_embedding/ae"] = self.xy_embedding(positions)

velocity = positions[:, 1] - positions[:, 0]
assert positions.size(-1) == 2
Expand Down Expand Up @@ -498,22 +539,58 @@ def forward(
device=traj_embed.device,
dtype=torch.int64,
)
traj_embed = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)
traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)

with autocast():
pred_noise = self.denoiser(traj_embed, input_tokens)
pred_noise = self.denoiser(traj_embed_noise, input_tokens)

noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
noise_loss = noise_loss[mask]
losses["diffusion"] = noise_loss.mean()

print(traj_embed_noise.shape, positions.shape)
losses["ae/with_noise"] = self.xy_embedding.loss(traj_embed_noise, positions)
losses["ae/with_pos"] = self.xy_embedding.loss(traj_embed, positions)

losses_backward(losses)

if writer and log_img:
# calculate cross_attn_weights
with torch.no_grad():
fig = plt.figure()

# generate prediction
self.train()

self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
pred_traj = torch.randn_like(noise[:1])
for timestep in self.eval_noise_scheduler.timesteps:
with autocast():
pred_traj = self.noise_scheduler.scale_model_input(
pred_traj, timestep
)
noise = self.denoiser(pred_traj, input_tokens[:1])
pred_traj = self.noise_scheduler.step(
noise,
timestep,
pred_traj,
generator=torch.Generator(device=device).manual_seed(0),
).prev_sample

pred_positions = self.xy_embedding.decode(pred_traj)[0].cpu()
plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred")

target = positions[0].detach().cpu()
plt.plot(target[..., 0], target[..., 1], label="target")

writer.add_scalar(
"paths/pred_mae",
F.l1_loss(pred_positions, target).item(),
global_step=global_step,
)

self.eval()

fig.legend()
plt.gca().set_aspect("equal")
writer.add_figure(
Expand All @@ -522,6 +599,4 @@ def forward(
global_step=global_step,
)

losses_backward(losses)

return losses
32 changes: 32 additions & 0 deletions torchdrive/tasks/test_diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
XEmbedding,
XYEmbedding,
XYLinearEmbedding,
XYMLPEncoder,
)


Expand Down Expand Up @@ -121,8 +122,39 @@ def test_diff_traj(self):
num_layers=2,
num_heads=1,
cam_shape=(48, 64),
num_inference_timesteps=2,
)

batch = dummy_batch()
writer = MagicMock()
losses = m(batch, global_step=0, writer=writer)

def test_xy_mlp_encoder(self):
torch.manual_seed(0)

m = XYMLPEncoder(
dim=32,
max_dist=1.0,
)

input = torch.tensor(
[
(0.0, 0.0),
(1.0, 0.0),
(0.0, 1.0),
(-1.0, 0.0),
(0.0, -1.0),
]
).unsqueeze(0)

out = m(input)
self.assertEqual(out.shape, (1, 5, 32))

decoded = m.decode(out)
self.assertEqual(decoded.shape, (1, 5, 2))

loss = m.loss(out, input)
self.assertEqual(loss.shape, (1, 5))
loss.sum().backward()
for param in m.parameters():
self.assertIsNotNone(param.grad)
4 changes: 3 additions & 1 deletion torchworld/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

x = x + self.encoder.encoder.pos_embedding

unmasked = x

# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, mask.sum())
x = x[:, :, mask]

Expand All @@ -54,4 +56,4 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
x = self.encoder.encoder.ln(
self.encoder.encoder.layers(self.encoder.encoder.dropout(x))
)
return self.project(x)
return unmasked, self.project(x)

0 comments on commit c120531

Please sign in to comment.