diff_traj: use per sequence mask
d4l3k committed Jul 7, 2024
1 parent a929c9e commit db40733
Showing 5 changed files with 99 additions and 22 deletions.
1 change: 1 addition & 0 deletions .torchxconfig
@@ -5,3 +5,4 @@ scheduler=local_cwd
[component:dist.ddp]
j=1x2
script=train.py
env=PYTHONBREAKPOINT=IPython.core.debugger.set_trace
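
For reference (not part of the diff): PEP 553's PYTHONBREAKPOINT variable is consulted every time the built-in breakpoint() is hit, so with this config line any breakpoint() in the launched train.py should drop into IPython's debugger instead of pdb. A minimal sketch, assuming IPython is installed:

import os
os.environ["PYTHONBREAKPOINT"] = "IPython.core.debugger.set_trace"  # same effect as the config line above
breakpoint()  # now resolves to IPython.core.debugger.set_trace() rather than pdb.set_trace()
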
97 changes: 79 additions & 18 deletions torchdrive/tasks/diff_traj.py
@@ -11,6 +11,7 @@
from torchdrive.amp import autocast
from torchdrive.autograd import autograd_context, register_log_grad_norm
from torchdrive.data import Batch
from torchdrive.debug import assert_not_nan, is_nan
from torchdrive.losses import losses_backward
from torchdrive.models.mlp import MLP
from torchdrive.models.path import XYEncoder
@@ -27,6 +28,37 @@
from torchworld.transforms.mask import random_block_mask, true_mask


def square_mask(mask: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
Create a squared mask from a sequence mask.
Arguments:
mask: the sequence mask (bs, seq_len)
num_heads: the number of heads
Returns:
the squared mask (bs*num_heads, seq_len, seq_len)
"""

bs, seq_len = mask.shape

# (bs, seq_len) -> (bs, 1, seq_len)
x = mask.unsqueeze(1)
# (bs, 1, seq_len) -> (bs, seq_len, seq_len)
x = x.expand(-1, seq_len, seq_len)

# (bs, seq_len) -> (bs, seq_len, 1)
y = mask.unsqueeze(2)
# (bs, seq_len, 1) -> (bs, seq_len, seq_len)
y = y.expand(-1, seq_len, seq_len)

mask = torch.logical_and(x, y).repeat(num_heads, 1, 1)

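# unmask the diagonal so every position can attend to itself; this keeps
# fully padded positions from producing all-masked attention rows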
diagonal = torch.arange(seq_len, device=mask.device)
mask[:, diagonal, diagonal] = True
return mask

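As an illustration (not part of the diff): for a batch where the second sequence has one padded step, square_mask keeps the full block for the fully valid sequence, restricts the padded sequence to its valid/diagonal entries, and tiles the result once per attention head.

import torch

mask = torch.tensor([[True, True],
                     [True, False]])
sq = square_mask(mask, num_heads=2)  # shape (2 * 2, 2, 2) = (bs * num_heads, seq_len, seq_len)
# sq[0] == [[True, True], [True, True]]    first sequence: every pair of steps is valid
# sq[1] == [[True, False], [False, True]]  second sequence: the padded step only attends to itself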

class Denoiser(nn.Module):
"""Transformer denoising model for 1d sequences"""

@@ -58,14 +90,20 @@ def __init__(
)
self.layers = nn.Sequential(layers)

def forward(self, input: torch.Tensor, condition: torch.Tensor):
def forward(
self, input: torch.Tensor, input_mask: torch.Tensor, condition: torch.Tensor
):
torch._assert(
input_mask.dim() == 2,
f"Expected (batch_size, seq_length) got {input_mask.shape}",
)
torch._assert(
input.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}",
)
torch._assert(
condition.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}",
f"Expected (batch_size, seq_length, hidden_dim) got {condition.shape}",
)

x = input
@@ -76,8 +114,12 @@ def forward(self, input: torch.Tensor, condition: torch.Tensor):
x = self.positional_embedding(x)
x = x.flatten(-2, -1)

for layer in self.layers:
x = layer(x, condition)
attn_mask = square_mask(input_mask, num_heads=self.num_heads)
# True values are ignored, so we need to flip the mask
attn_mask = torch.logical_not(attn_mask)

for i, layer in enumerate(self.layers):
x = layer(tgt=x, tgt_mask=attn_mask, memory=condition)

return x

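The logical_not above matches the convention of nn.TransformerDecoderLayer / nn.MultiheadAttention, where a boolean attention mask marks with True the positions that may not be attended to, while square_mask marks valid pairs with True. A minimal sketch of the same pattern with a stock decoder layer (the commit's actual Denoiser layers may differ; shapes and names here are assumptions, and batch size 1 keeps the example independent of how the per-head tiling is ordered):

import torch
import torch.nn as nn

seq_len, dim, num_heads = 4, 32, 4
layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads, batch_first=True)

seq_mask = torch.tensor([[True, True, False, False]])                      # (1, seq_len), last 2 steps padded
attn_mask = torch.logical_not(square_mask(seq_mask, num_heads=num_heads))  # True = do not attend

x = torch.randn(1, seq_len, dim)     # noisy trajectory embedding
condition = torch.randn(1, 6, dim)   # conditioning tokens
out = layer(tgt=x, memory=condition, tgt_mask=attn_mask)  # (1, seq_len, dim)
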
@@ -501,8 +543,6 @@ def forward(
positions /= positions[..., -1:] + 1e-8 # perspective warp
positions = positions[..., :2]

# losses["xy_embedding/ae"] = self.xy_embedding(positions)

# calculate velocity between first two frames to allow model to understand current speed
# TODO: convert this to a categorical embedding
velocity = positions[:, 1] - positions[:, 0]
@@ -512,14 +552,25 @@
static_features = self.static_features_encoder(velocity).unsqueeze(1)

lengths = mask.sum(dim=-1)
pos_len = lengths.amax()
min_len = lengths.amin()
assert min_len > 0, "got example with zero sequence length"

# truncate to shortest sequence
# pos_len = lengths.amin()
# if pos_len % align != 0:
# pos_len -= pos_len % align
# assert pos_len >= 8
# positions = positions[:, :pos_len]
# mask = mask[:, :pos_len]

# we need to be aligned to size 8
# pad length
align = 8
if positions.size(1) % align != 0:
pad = align - positions.size(1) % align
mask = F.pad(mask, (0, pad), value=True)
positions = F.pad(positions, (0, 0, 0, pad), value=0)
pos_len = positions.size(1)

assert positions.size(1) % align == 0
assert mask.size(1) % align == 0
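
For illustration, the alignment logic above amounts to rounding the sequence dimension up to the next multiple of align. The notable design choice is the mask's pad value: padding with True (as in this commit) keeps the zero-padded steps visible to attention and the loss, while padding with False would exclude them. A standalone sketch with assumed shapes:

import torch
import torch.nn.functional as F

def pad_to_multiple(positions, mask, align=8, mask_pad_value=True):
    # positions: (bs, seq_len, 2) float, mask: (bs, seq_len) bool
    rem = positions.size(1) % align
    if rem != 0:
        pad = align - rem
        positions = F.pad(positions, (0, 0, 0, pad), value=0)  # pad the seq_len dim with zeros
        mask = F.pad(mask, (0, pad), value=mask_pad_value)
    return positions, mask

positions, mask = pad_to_multiple(torch.randn(2, 13, 2), torch.ones(2, 13, dtype=torch.bool))
assert positions.size(1) % 8 == 0 and mask.size(1) == positions.size(1)
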
@@ -529,7 +580,7 @@

if writer and log_text:
writer.add_scalar(
"paths/seq_len",
"paths/pos_len",
pos_len,
global_step=global_step,
)
@@ -558,16 +609,16 @@
# add static feature info to all condition keys to avoid noise
input_tokens = input_tokens + static_features

pred_noise = self.denoiser(traj_embed_noise, input_tokens)
pred_noise = self.denoiser(traj_embed_noise, mask, input_tokens)

noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
noise_loss = noise_loss[mask]
losses["diffusion"] = noise_loss.mean()

losses["ae/with_noise"] = self.xy_embedding.loss(
traj_embed_noise, positions
).mean()
losses["ae/ae"] = self.xy_embedding.loss(traj_embed, positions).mean()
losses["ae/with_noise"] = self.xy_embedding.loss(traj_embed_noise, positions)[
mask
].mean()
losses["ae/ae"] = self.xy_embedding.loss(traj_embed, positions)[mask].mean()

losses_backward(losses)

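The loss changes follow the same per-sequence idea: compute the elementwise loss, keep only the steps where the mask is True, then reduce, so padded or invalid steps no longer dilute the mean. A small sketch with assumed shapes ((bs, seq_len, dim) predictions, (bs, seq_len) mask):

import torch
import torch.nn.functional as F

pred = torch.randn(2, 8, 16)
target = torch.randn(2, 8, 16)
mask = torch.tensor([[True] * 8,
                     [True] * 5 + [False] * 3])  # second sequence has 3 padded steps

loss = F.mse_loss(pred, target, reduction="none")  # (2, 8, 16)
loss = loss[mask].mean()                           # boolean indexing keeps only valid steps: (13, 16)
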
@@ -579,42 +630,52 @@
# generate prediction
self.train()

pred_len = mask[0].sum()

pred_traj = torch.randn_like(noise[:1])
self.eval_noise_scheduler.set_timesteps(self.num_inference_timesteps)
for timestep in self.eval_noise_scheduler.timesteps:
with autocast():
pred_traj = self.eval_noise_scheduler.scale_model_input(
pred_traj, timestep
)
noise = self.denoiser(pred_traj, input_tokens[:1])
noise = self.denoiser(pred_traj, mask[:1], input_tokens[:1])
pred_traj = self.eval_noise_scheduler.step(
noise,
timestep,
pred_traj,
generator=torch.Generator(device=device).manual_seed(0),
).prev_sample

pred_positions = self.xy_embedding.decode(pred_traj)[0].cpu()
pred_positions = self.xy_embedding.decode(pred_traj)[0, :pred_len].cpu()
plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="pred")

noise_positions = self.xy_embedding.decode(traj_embed_noise[:1])[
0
0,
:pred_len,
].cpu()
plt.plot(
noise_positions[..., 0], noise_positions[..., 1], label="with_noise"
)

pos_positions = self.xy_embedding.decode(traj_embed[:1])[0].cpu()
pos_positions = self.xy_embedding.decode(traj_embed[:1])[
0, :pred_len
].cpu()
plt.plot(pos_positions[..., 0], pos_positions[..., 1], label="ae")

target = positions[0].detach().cpu()
target = positions[0, :pred_len].detach().cpu()
plt.plot(target[..., 0], target[..., 1], label="target")

writer.add_scalar(
"paths/pred_mae",
F.l1_loss(pred_positions, target).item(),
global_step=global_step,
)
writer.add_scalar(
"paths/pred_len",
pred_len,
global_step=global_step,
)

self.eval()

19 changes: 18 additions & 1 deletion torchdrive/tasks/test_diff_traj.py
@@ -1,11 +1,12 @@
import random
import unittest
from unittest.mock import MagicMock, patch

import torch
from torchdrive.data import Batch, dummy_batch

from torchdrive.tasks.diff_traj import (
DiffTraj,
square_mask,
XEmbedding,
XYEmbedding,
XYLinearEmbedding,
@@ -114,6 +115,7 @@ def test_x_embedding(self):

def test_diff_traj(self):
torch.manual_seed(0)
random.seed(0)

m = DiffTraj(
cameras=["left"],
@@ -158,3 +160,18 @@ def test_xy_mlp_encoder(self):
loss.sum().backward()
for param in m.parameters():
self.assertIsNotNone(param.grad)

def test_square_mask(self):
input = torch.tensor(
[
[True, True],
[True, False],
]
)
target = torch.tensor(
[[[True, True], [True, True]], [[True, False], [False, True]]]
)

output = square_mask(input, num_heads=3)
self.assertEqual(output.shape, (6, 2, 2))
torch.testing.assert_close(output[:2], target)
2 changes: 0 additions & 2 deletions torchworld/models/vit.py
@@ -52,8 +52,6 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

unmasked = x

print(x.shape, mask.shape)

# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, mask.sum())
x = x[:, :, mask]

2 changes: 1 addition & 1 deletion train.py
@@ -60,7 +60,7 @@
device = torch.device("cpu")

torch.set_float32_matmul_precision("high")
sdpa_kernel(SDPBackend.FLASH_ATTENTION).__enter__() # force flash attention
# sdpa_kernel(SDPBackend.FLASH_ATTENTION).__enter__() # force flash attention

BS: int = config.batch_size
NUM_EPOCHS: int = config.epochs
