Skip to content

Commit

Permalink
Update test_tracking.py
Browse files Browse the repository at this point in the history
  • Loading branch information
HaozheQi committed Mar 9, 2020
1 parent ccac70f commit 9015b91
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test(loader,model,epoch=-1,shape_aggregation="",reference_BB="",model_fusion
offset=dataset.offset_BB,
scale=dataset.scale_BB)

candidate_PCs,candidate_labels,candidate_reg = utils.regularizePCwithlabel(candidate_PC, candidate_label,candidate_reg,dataset.input_size)
candidate_PCs,candidate_labels,candidate_reg = utils.regularizePCwithlabel(candidate_PC, candidate_label,candidate_reg,dataset.input_size,istrain=False)

candidate_PCs_torch = candidate_PCs.unsqueeze(0).cuda()

Expand All @@ -96,7 +96,7 @@ def test(loader,model,epoch=-1,shape_aggregation="",reference_BB="",model_fusion
else:
model_PC = utils.getModel(PCs[:i],results_BBs,offset=dataset.offset_BB,scale=dataset.scale_BB)

model_PC_torch = utils.regularizePC(model_PC, dataset.input_size).unsqueeze(0)
model_PC_torch = utils.regularizePC(model_PC, dataset.input_size,istrain=False).unsqueeze(0)
model_PC_torch = Variable(model_PC_torch, requires_grad=False).cuda()
candidate_PCs_torch = Variable(candidate_PCs_torch, requires_grad=False).cuda()

Expand Down

0 comments on commit 9015b91

Please sign in to comment.