Skip to content

Commit

Permalink
torchdrive/path: add final label
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Oct 29, 2023
1 parent 8e92054 commit 2ebd964
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torchdrive/tasks/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward(
length = lengths[0] - 1
plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "ro", label="final")

for i, predicted in enumerate(all_predicted):
if i % max(1, self.num_ar_iters // 4) != 0:
Expand All @@ -148,16 +148,20 @@ def forward(
autoregressive = PathTransformer.infer(
self.transformer,
bev[:1],
positions[:1, ..., :2],
positions[:1, ..., :1],
final_pos[:1],
n=length - 2,
n=length - 1,
)
assert autoregressive.shape == (1, 3, length), (
autoregressive.shape,
length,
)
plt.plot(*target[0, 0:2, :length].detach().cpu(), label="target")
plt.plot(
*autoregressive[0, 0:2, 1:].detach().cpu(), label="autoregressive"
)
plt.plot(*prev[0, 0:2, 0].detach().cpu(), "go", label="origin")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "go", label="final")
plt.plot(*final_pos[0, 0:2].detach().cpu(), "ro", label="final")
self.train()

fig.legend()
Expand Down

0 comments on commit 2ebd964

Please sign in to comment.