Skip to content

Commit

Permalink
diff_traj: fix logging
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 7, 2024
1 parent 733215b commit 1d561c8
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions torchdrive/tasks/diff_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ def forward(
cam_feats = cam_feats.unflatten(0, feats.shape[:2])

if writer is not None and log_img:
print(cam_feats.shape)
writer.add_image(
f"{cam}/pca",
render_color(cam_feats[0, 0]),
Expand Down Expand Up @@ -473,8 +472,16 @@ def forward(
num_elements = mask.float().sum()

if writer and log_text:
writer.add_scalar("paths/seq_len", pos_len)
writer.add_scalar("paths/num_elements", num_elements)
writer.add_scalar(
"paths/seq_len",
pos_len,
global_step=global_step,
)
writer.add_scalar(
"paths/num_elements",
num_elements,
global_step=global_step,
)

posmax = positions.abs().amax()
assert posmax < 1000, positions
Expand Down Expand Up @@ -509,7 +516,11 @@ def forward(

fig.legend()
plt.gca().set_aspect("equal")
writer.add_figure("paths/target", fig)
writer.add_figure(
"paths/target",
fig,
global_step=global_step,
)

losses_backward(losses)

Expand Down

0 comments on commit 1d561c8

Please sign in to comment.