Skip to content

Commit

Permalink
Fix PyTorch SAM tests (#23682)
Browse files Browse the repository at this point in the history
fix

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored May 23, 2023
1 parent b687af0 commit abf691a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/models/sam/test_modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_inference_mask_generation_no_point(self):
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1807, -3.4949, -3.4483]).to(torch_device), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))

def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-base")
Expand All @@ -499,7 +499,7 @@ def test_inference_mask_generation_one_point_one_bb(self):
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-12.7657, -12.3683, -12.5985]).to(torch_device), atol=2e-4)
torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
)

def test_inference_mask_generation_batched_points_batched_images(self):
Expand Down Expand Up @@ -540,7 +540,7 @@ def test_inference_mask_generation_batched_points_batched_images(self):
],
]
)
EXPECTED_MASKS = torch.tensor([-2.8552, -2.7990, -2.9612])
EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

Expand Down Expand Up @@ -568,7 +568,7 @@ def test_inference_mask_generation_one_point_one_bb_zero(self):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7892), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))

def test_inference_mask_generation_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-base")
Expand Down

0 comments on commit abf691a

Please sign in to comment.