Skip to content

Commit

Permalink
diff_traj: vista is training
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 23, 2024
1 parent e798e57 commit ba40295
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 6 deletions.
2 changes: 1 addition & 1 deletion configs/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
autolabel_path=None, # "/mnt/ext3/autolabel2",
mask_path="n/a", # only used for rice dataset
num_workers=16,
batch_size=64,
batch_size=4,
autolabel=False,
)
110 changes: 110 additions & 0 deletions notebooks/compute_dream_pos.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions torchdrive/models/vista.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
from typing import Tuple
import time

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -51,10 +52,14 @@ def __init__(
self.num_frames = num_frames
self.render_size = render_size

start = time.perf_counter()

config = OmegaConf.load(config_path)
model = load_model_from_config(config, ckpt_path)
self.model = model.bfloat16().to(device).eval()

print(f"loaded vista in {time.perf_counter() - start:.2f}s")

guider = "VanillaCFG"
self.sampler = init_sampling(
guider=guider,
Expand Down
136 changes: 131 additions & 5 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
render_color,
render_pca,
)
from torchworld.transforms.transform3d import Transform3d
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_euler_angles
from torchworld.transforms.mask import random_block_mask, true_mask
from torchdrive.models.vista import VistaSampler


def square_mask(mask: torch.Tensor, num_heads: int) -> torch.Tensor:
Expand Down Expand Up @@ -746,6 +749,8 @@ def __init__(

self.model = ConvNextPathPred()

self.vista = VistaSampler()

"""
self.encoders = nn.ModuleDict(
{
Expand Down Expand Up @@ -791,10 +796,10 @@ def __init__(

self.batch_transform = Compose(
NormalizeCarPosition(start_frame=0),
ImageTransform(
v2.RandomRotation(15, InterpolationMode.BILINEAR),
v2.RandomErasing(),
),
#ImageTransform(
# v2.RandomRotation(15, InterpolationMode.BILINEAR),
# v2.RandomErasing(),
#),
)

def param_opts(self, lr: float) -> List[Dict[str, object]]:
Expand Down Expand Up @@ -839,7 +844,7 @@ def param_opts(self, lr: float) -> List[Dict[str, object]]:
]

def should_log(self, global_step: int, BS: int) -> Tuple[bool, bool]:
log_text_interval = 1000 // BS
log_text_interval = 10 // BS
# log_text_interval = 1
# It's important to scale the less frequent interval off the more
# frequent one to avoid divisor issues.
Expand Down Expand Up @@ -1066,6 +1071,39 @@ def forward(

pred_len = min(pred_traj.size(1), mask[0].sum().item())

dreamed_imgs = []
for i in range(BS):
cond_img = batch.color[cam][i:i+1, 0]
cond_traj = pred_traj[i:i+1]

dreamed_img = self.vista.generate(cond_img, cond_traj)
# add last img (frame 10 == 1s)
dreamed_imgs.append(dreamed_img[-1])

# [BS, 1, 3, H, W]
dream_img = torch.stack(dreamed_imgs, dim=0).unsqueeze(1)

if log_img:
ctx.add_image(
f"{cam}/dream",
normalize_img(dream_img[0, 0]),
)

pred_traj_len = min(positions.size(1), pred_traj.size(1))
dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos(
positions[:, :pred_traj_len],
mask[:, :pred_traj_len],
pred_traj[:, :pred_traj_len],
step=2,
)

dream_losses, dream_traj, all_dream_traj = self.model(
velocity, dream_img, dream_target, dream_mask
)
for k, v in dream_losses.items():
losses[f"dream-{k}"] = v


# noise_loss, noise_traj = self.y_embedding.loss(traj_embed_noise, positions)
# losses["ae/with_noise"] = (
# noise_loss.mean() * 0.01
Expand Down Expand Up @@ -1158,4 +1196,92 @@ def forward(
fig,
)

with torch.no_grad():
fig = plt.figure()

pred_len = min(pred_len, dream_mask[0].sum().item())

og_target = dream_positions[0, :pred_len].detach().cpu()
plt.plot(og_target[..., 0], og_target[..., 1], label="positions")

target = dream_target[0, :pred_len].detach().cpu()
plt.plot(target[..., 0], target[..., 1], label="dream_target")

pred_positions = dream_pred[0, :pred_len].cpu()
plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="og_pred")

pred_positions = dream_traj[0, :pred_len].cpu()
plt.plot(pred_positions[..., 0], pred_positions[..., 1], label="new_pred")

fig.legend()
plt.gca().set_aspect("equal")

ctx.add_figure(
"dream-paths/target",
fig,
)

return losses

def compute_dream_pos(positions: torch.Tensor, mask: torch.Tensor, pred_traj: torch.Tensor, step: int=2) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute a new ground truth trajectory for the dreamer to use as a loss.
Outputted directory is centered at 0,0 and uses the new direction.
Args:
positions: (B, T, 2) the ground truth trajectory
mask: (B, T) mask for the ground truth trajectory
pred_traj: (B, T, 2) the trajectory the dreamer is following
step: pred_traj[:, step] is the new root position
Returns:
dream_target: (B, T-step, 2) the new ground truth trajectory in step coordinate frame
dream_mask: (B, T-step) the new mask for the ground truth trajectory
dream_positions: (B, T-step, 2) positions in step coordinate frame
dream_pred: (B, T-step, 2) pred_traj in step coordinate frame
"""
direction = pred_traj[:, step] - pred_traj[:, step-1]

angle = torch.atan2(direction[:, 1], direction[:, 0])
rot = torch.stack([
torch.stack([
torch.cos(angle),
-torch.sin(angle),
], dim=-1),
torch.stack([
torch.sin(angle),
torch.cos(angle),
], dim=-1)
], dim=-1)
rot = rot.pinverse()

# drop old points
positions = positions[:, step:]
mask = mask[:, step:]
pred_traj = pred_traj[:, step:]

# use linear interpolation between pred_traj and positions
#factor = torch.arange(0, positions.size(1), device=positions.device) / (positions.size(1) - 1)
#factor = factor.unsqueeze(0).unsqueeze(-1)
#dream_pos = pred_traj * (1-factor) + positions * factor

# use ema interpolation between pred_traj and positions
factor = torch.full((positions.size(1),), 0.5, device=positions.device)
factor[0] = 1.0
factor = torch.cumprod(factor, dim=0)
factor = factor.unsqueeze(0).unsqueeze(-1)

dream_pos = pred_traj * factor + positions * (1-factor)

origin = dream_pos[:, 0:1]

# center dream_pos according to direction
dream_pos = dream_pos - origin
pred_traj = pred_traj - origin
positions = positions - origin

# reorientate
dream_pos = dream_pos.matmul(rot)
pred_traj = pred_traj.matmul(rot)
positions = positions.matmul(rot)

return dream_pos, mask, positions, pred_traj
13 changes: 13 additions & 0 deletions torchdrive/tasks/test_diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
XYLinearEmbedding,
XYMLPEncoder,
XYSineMLPEncoder,
compute_dream_pos,
)


Expand Down Expand Up @@ -215,3 +216,15 @@ def test_square_mask(self):
output = square_mask(input, num_heads=3)
self.assertEqual(output.shape, (6, 2, 2))
torch.testing.assert_close(output[:2], target)

def test_compute_dream_pos(self):

positions = torch.rand(2, 18, 2)
mask = torch.ones(2, 18)
pred_traj = torch.rand(2, 18, 2)

dream_target, dream_mask, dream_positions, dream_pred = compute_dream_pos(positions, mask, pred_traj)
self.assertEqual(dream_target.shape, (2, 16, 2))
self.assertEqual(dream_mask.shape, (2, 16))
self.assertEqual(dream_positions.shape, (2, 16, 2))
self.assertEqual(dream_pred.shape, (2, 16, 2))

0 comments on commit ba40295

Please sign in to comment.