Skip to content

Commit

Permalink
[SAM] Fixes pipeline and adds a dummy pipeline test (huggingface#23684
Browse files Browse the repository at this point in the history
)

* add a dummy pipeline test

* change test name
  • Loading branch information
younesbelkada authored and novice03 committed Jun 23, 2023
1 parent 4671310 commit d596555
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/sam/image_processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ def _generate_crop_boxes(
cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)

crop_boxes = np.array(crop_boxes)
crop_boxes = crop_boxes.astype(np.float32)
points_per_crop = np.array([point_grid_per_crop])
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
Expand Down
8 changes: 7 additions & 1 deletion tests/models/sam/test_modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import requests

from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

Expand Down Expand Up @@ -751,3 +751,9 @@ def test_inference_mask_generation_three_boxes_point_batch(self):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 3, 3))
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)

def test_dummy_pipeline_generation(self):
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
raw_image = prepare_image()

_ = generator(raw_image, points_per_batch=64)

0 comments on commit d596555

Please sign in to comment.