Skip to content

Commit

Permalink
diff_traj: adjust LR and add static features
Browse files: browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 7, 2024
1 parent df20a77 commit 2accf68
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions torchdrive/tasks/diff_traj.py
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ def __init__(
num_heads: int = 16,
num_encode_frames: int = 1,
num_frames: int = 1,
- num_inference_timesteps: int = 20,
+ num_inference_timesteps: int = 50,
num_train_timesteps: int = 1000,
):
super().__init__()
@@ -372,12 +372,12 @@ def param_opts(self, lr: float) -> List[Dict[str, object]]:
{
"name": "encoders",
"params": list(self.encoders.parameters()),
- "lr": lr / len(self.encoders),
+ "lr": lr / 10,
},
{
"name": "static_features",
"params": list(self.static_features_encoder.parameters()),
- "lr": lr,
+ "lr": lr / 10,
},
{
"name": "denoiser",
@@ -387,7 +387,7 @@ def param_opts(self, lr: float) -> List[Dict[str, object]]:
{
"name": "xy_embedding",
"params": list(self.xy_embedding.parameters()),
- "lr": lr,
+ "lr": lr / 10,
},
]

@@ -503,11 +503,13 @@ def forward(

# losses["xy_embedding/ae"] = self.xy_embedding(positions)

# calculate velocity between first two frames to allow model to understand current speed
# TODO: convert this to a categorical embedding
velocity = positions[:, 1] - positions[:, 0]
assert positions.size(-1) == 2
velocity = torch.linalg.vector_norm(velocity, dim=-1, keepdim=True)

- static_features = self.static_features_encoder(velocity)
+ static_features = self.static_features_encoder(velocity).unsqueeze(1)

lengths = mask.sum(dim=-1)
pos_len = lengths.amax()
@@ -553,6 +555,9 @@ def forward(
traj_embed_noise = self.noise_scheduler.add_noise(traj_embed, noise, timesteps)

with autocast():
# add static feature info to all condition keys to avoid noise
input_tokens = input_tokens + static_features

pred_noise = self.denoiser(traj_embed_noise, input_tokens)

noise_loss = F.mse_loss(pred_noise, noise, reduction="none")
Expand Down

0 comments on commit 2accf68

Please sign in to comment.