Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SAM tests and use smaller checkpoints #23656

Merged
merged 3 commits on May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 50 additions & 54 deletions tests/models/sam/test_modeling_sam.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -436,8 +436,9 @@ def test_retain_grad_hidden_states_attentions(self):
def test_hidden_states_output(self):
pass

def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# Use a slightly higher default tol to make the tests non-flaky
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

@slow
def test_model_from_pretrained(self):
Expand All @@ -461,8 +462,8 @@ def prepare_dog_img():
@slow
class SamModelIntegrationTest(unittest.TestCase):
def test_inference_mask_generation_no_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -474,13 +475,12 @@ def test_inference_mask_generation_no_point(self):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1807, -3.4949, -3.4483]).to(torch_device), atol=2e-4))

def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -497,15 +497,14 @@ def test_inference_mask_generation_one_point_one_bb(self):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
torch.allclose(masks, torch.tensor([-12.7657, -12.3683, -12.5985]).to(torch_device), atol=2e-4)
)

def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -528,26 +527,26 @@ def test_inference_mask_generation_batched_points_batched_images(self):
EXPECTED_SCORES = torch.tensor(
[
[
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.6765, 0.9379, 0.8803],
[0.6765, 0.9379, 0.8803],
[0.6765, 0.9379, 0.8803],
[0.6765, 0.9379, 0.8803],
],
[
[0.8405, 0.6292, 0.3840],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.3317, 0.7264, 0.7646],
[0.6765, 0.9379, 0.8803],
[0.6765, 0.9379, 0.8803],
[0.6765, 0.9379, 0.8803],
],
]
)
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
EXPECTED_MASKS = torch.tensor([-2.8552, -2.7990, -2.9612])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -569,11 +568,11 @@ def test_inference_mask_generation_one_point_one_bb_zero(self):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9689), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7892), atol=1e-4))

def test_inference_mask_generation_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -590,8 +589,7 @@ def test_inference_mask_generation_one_point(self):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

# With no label
input_points = [[[400, 650]]]
Expand All @@ -601,12 +599,11 @@ def test_inference_mask_generation_one_point(self):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

def test_inference_mask_generation_two_points(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -623,8 +620,7 @@ def test_inference_mask_generation_two_points(self):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

# no labels
inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)
Expand All @@ -633,11 +629,11 @@ def test_inference_mask_generation_two_points(self):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

def test_inference_mask_generation_two_points_batched(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -654,13 +650,12 @@ def test_inference_mask_generation_two_points_batched(self):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9936), atol=1e-4))
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9716), atol=1e-4))
self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))

def test_inference_mask_generation_one_box(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -674,12 +669,11 @@ def test_inference_mask_generation_one_box(self):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()

self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8686), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))

def test_inference_mask_generation_batched_image_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand Down Expand Up @@ -707,8 +701,8 @@ def test_inference_mask_generation_batched_image_one_point(self):
self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))

def test_inference_mask_generation_two_points_point_batch(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -729,12 +723,12 @@ def test_inference_mask_generation_two_points_point_batch(self):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 2, 3))
torch.testing.assert_allclose(
iou_scores, torch.tensor([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), atol=1e-4, rtol=1e-4
iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
)

def test_inference_mask_generation_three_boxes_point_batch(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model.to(torch_device)
model.eval()
Expand All @@ -743,7 +737,9 @@ def test_inference_mask_generation_three_boxes_point_batch(self):

# fmt: off
input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu()
EXPECTED_IOU = torch.tensor([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]])
EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
[0.5996, 0.7661, 0.7937],
[0.5996, 0.7661, 0.7937]]])
# fmt: on
input_boxes = input_boxes.unsqueeze(0)

Expand Down
Loading