diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
index 6f4aa166715..dd650486e30 100644
--- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
+++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -16,7 +16,7 @@
 import torch
 import torchvision.transforms.v2 as tvt_v2
 from datumaro import Polygon as dmPolygon
-from torch import LongTensor, Tensor, nn
+from torch import Tensor, nn
 from torch.nn import functional as F  # noqa: N812
 from torchvision import tv_tensors
 from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor
@@ -764,21 +764,38 @@ def _customize_outputs(  # type: ignore[override]
         masks: list[Mask] = []
         prompts: list[Points] = []
         scores: list[Tensor] = []
-        labels: list[LongTensor] = []
-        for predicted_masks, used_points in outputs:
+        labels: list[Tensor] = []
+        for idx, (predicted_masks, used_points) in enumerate(outputs):
+            _masks: list[Tensor] = []
+            _prompts: list[Tensor] = []
+            _scores: list[Tensor] = []
+            _labels: list[Tensor] = []
             for label, predicted_mask in predicted_masks.items():
                 if len(predicted_mask) == 0:
                     continue
-                masks.append(Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32))
-                prompts.append(
-                    Points(
-                        torch.stack([p[:2] for p in used_points[label]], dim=0),
-                        canvas_size=inputs.imgs_info[0].ori_shape,
-                        dtype=torch.float32,
+                _masks.append(torch.stack(predicted_mask, dim=0))
+                _used_points_scores = torch.stack(used_points[label], dim=0)
+                _prompts.append(_used_points_scores[:, :2])
+                _scores.append(_used_points_scores[:, 2])
+                _labels.append(torch.tensor([label] * len(_used_points_scores), dtype=torch.int64, device=self.device))
+
+            if len(_masks) == 0:
+                masks.append(
+                    tv_tensors.Mask(
+                        torch.zeros((1, *inputs.imgs_info[idx].ori_shape), dtype=torch.float32, device=self.device),
                     ),
                 )
-                scores.append(torch.stack([p[2] for p in used_points[label]], dim=0))
-                labels.append(torch.cat([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0))
+                prompts.append(
+                    Points([], canvas_size=inputs.imgs_info[idx].ori_shape, dtype=torch.float32, device=self.device),
+                )
+                scores.append(torch.tensor([-1.0], dtype=torch.float32, device=self.device))
+                labels.append(torch.tensor([-1], dtype=torch.int64, device=self.device))
+                continue
+
+            masks.append(tv_tensors.Mask(torch.cat(_masks, dim=0)))
+            prompts.append(Points(torch.cat(_prompts, dim=0), canvas_size=inputs.imgs_info[idx].ori_shape))
+            scores.append(torch.cat(_scores, dim=0))
+            labels.append(torch.cat(_labels, dim=0))
 
         return ZeroShotVisualPromptingBatchPredEntity(
             batch_size=len(outputs),
@@ -926,15 +943,19 @@ def load_reference_info(
             log.info(f"reference info saved at {path_to_directly_load} was successfully loaded.")
         else:
-            _infer_reference_info_root: Path = (
-                self.infer_reference_info_root
-                if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
-                else Path(default_root_dir) / self.infer_reference_info_root
-            )
+            if str(self.infer_reference_info_root) == "../.latest/train":
+                # for default setting
+                path_reference_info = (
+                    Path(default_root_dir)
+                    / self.infer_reference_info_root
+                    / self.reference_info_dir
+                    / "reference_info.pt"
+                )
+            else:
+                # for user input
+                path_reference_info = self.infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
 
-            if (
-                path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
-            ).is_file():
+            if path_reference_info.is_file():
                 reference_info = torch.load(path_reference_info)
                 retval = True
                 log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
 
diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py
index 7e69fd4ec3b..3db822fb46a 100644
--- a/src/otx/core/model/visual_prompting.py
+++ b/src/otx/core/model/visual_prompting.py
@@ -43,6 +43,7 @@
 
 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+    from model_api.models.utils import PredictedMask, VisualPromptingResult
     from torchmetrics import MetricCollection
 
     from otx.core.data.module import OTXDataModule
@@ -139,23 +140,6 @@ def _inference_step_for_zero_shot(
                 for ett in converted_entities["preds"]
             ]
             _target = converted_entities["target"]
-
-            # match #_preds and #_target
-            if len(_preds) > len(_target):
-                # interpolate _target
-                num_diff = len(_preds) - len(_target)
-                for idx in range(num_diff):
-                    _target.append(_target[idx])
-            elif len(_preds) < len(_target):
-                num_diff = len(_target) - len(_preds)
-                pad_prediction = {
-                    "masks": torch.zeros_like(_target[0]["masks"], dtype=_target[0]["masks"].dtype),
-                    "labels": torch.zeros_like(_target[0]["labels"], dtype=_target[0]["labels"].dtype),
-                    "scores": torch.zeros(len(_target[0]["labels"]), dtype=torch.float32),
-                }  # for empty prediction
-                for idx in range(num_diff):
-                    _preds.append(_preds[idx] if idx < len(_preds) else pad_prediction)
-
             _metric.update(preds=_preds, target=_target)
         elif _name in ["iou", "f1-score", "dice"]:
             # BinaryJaccardIndex, BinaryF1Score, Dice
@@ -345,6 +329,7 @@ def on_test_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir, self.device):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()  # to use infer logic
             self.training = False
 
@@ -362,6 +347,7 @@ def on_predict_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir, self.device):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()  # to use infer logic
             self.training = False
 
@@ -530,7 +516,7 @@ def forward(
             )
 
         images, batch_prompts = self._customize_inputs(inputs)
-        outputs: list[Any] = []
+        outputs: list[VisualPromptingResult] = []
         for image, prompt in zip(images, batch_prompts):
             outputs.append(self.model(image, **prompt))
 
@@ -576,31 +562,28 @@ def _customize_inputs(  # type: ignore[override]
 
     def _customize_outputs(
         self,
-        outputs: Any,  # noqa: ANN401
+        outputs: list[VisualPromptingResult],
        inputs: VisualPromptingBatchDataEntity,  # type: ignore[override]
     ) -> VisualPromptingBatchPredEntity:
         """Customize OTX output batch data entity if needed for model."""
         masks: list[tv_tensors.Mask] = []
-        scores: list[torch.Tensor] = []
+        scores: list[Tensor] = []
+        labels: list[Tensor] = []
         for image_output in outputs:
-            masks.extend(
-                [
-                    torch.as_tensor(hard_prediction, device=self.device)
-                    for hard_prediction in image_output.hard_predictions
-                ],
-            )
-            scores.extend([torch.as_tensor(score, device=self.device) for score in image_output.scores])
+            masks.append(tv_tensors.Mask(np.concatenate(image_output.hard_predictions), device=self.device))
+            scores.append(torch.as_tensor(np.concatenate(image_output.scores), device=self.device))
+            labels.append(torch.as_tensor(image_output.labels, device=self.device))
 
         return VisualPromptingBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
             imgs_info=inputs.imgs_info,
-            scores=[torch.cat(scores, dim=0)],
-            masks=[tv_tensors.Mask(torch.cat(masks, dim=0))],
+            scores=scores,
+            masks=masks,
             polygons=[],
             points=[],
             bboxes=[],
-            labels=[torch.cat(list(labels.values())) for labels in inputs.labels],
+            labels=labels,
         )
 
     def optimize(  # type: ignore[override]
@@ -884,11 +867,11 @@ def infer(
         inputs: ZeroShotVisualPromptingBatchDataEntity,
         reference_feats: np.ndarray,
         used_indices: np.ndarray,
-    ) -> list[list[defaultdict[int, list]]]:
+    ) -> list[dict[int, PredictedMask]]:
         """`Infer` for target predictions."""
         images, _ = self._customize_inputs(inputs)
 
-        total_results: list[list[defaultdict[int, list]]] = []
+        total_results: list[dict[int, PredictedMask]] = []
         for image in images:
             result = self.model(image, VisualPromptingFeatures(reference_feats, used_indices))
             total_results.append(result.data)
@@ -898,7 +881,7 @@ def forward(  # type: ignore[override]
         self,
         inputs: ZeroShotVisualPromptingBatchDataEntity,  # type: ignore[override]
-    ) -> ZeroShotVisualPromptingBatchPredEntity:
+    ) -> tuple[dict[str, np.ndarray], list[np.ndarray]] | ZeroShotVisualPromptingBatchPredEntity:
         """Model forward function."""
         kwargs: dict[str, Any] = {}
         fn = self.learn if self.training else self.infer
@@ -959,42 +942,57 @@ def _customize_inputs(  # type: ignore[override]
 
     def _customize_outputs(  # type: ignore[override]
         self,
-        outputs: Any,  # noqa: ANN401
+        outputs: tuple[dict[str, np.ndarray], list[np.ndarray]] | list[dict[int, PredictedMask]],
         inputs: ZeroShotVisualPromptingBatchDataEntity,  # type: ignore[override]
-    ) -> ZeroShotVisualPromptingBatchPredEntity:
+    ) -> tuple[dict[str, np.ndarray], list[np.ndarray]] | ZeroShotVisualPromptingBatchPredEntity:
         """Customize OTX output batch data entity if needed for model."""
-        if self.training:
+        if self.training and isinstance(outputs, tuple):
             return outputs
 
         masks: list[tv_tensors.Mask] = []
         prompts: list[Points] = []
-        scores: list[torch.Tensor] = []
-        labels: list[torch.LongTensor] = []
-        for output in outputs:
+        scores: list[Tensor] = []
+        labels: list[Tensor] = []
+        for idx, output in enumerate(outputs):
+            if not isinstance(output, dict):
+                continue
+            _masks: list[np.ndarray] = []
+            _prompts: list[np.ndarray] = []
+            _scores: list[np.ndarray] = []
+            _labels: list[np.ndarray] = []
             for label, predicted_mask in output.items():
                 if len(predicted_mask.mask) == 0:
                     continue
+                _masks.append(np.stack(predicted_mask.mask, axis=0))
+                _used_points_scores = np.stack(predicted_mask.points, axis=0)
+                _prompts.append(_used_points_scores[:, :2])
+                _scores.append(_used_points_scores[:, 2])
+                _labels.append(np.array([label] * len(_used_points_scores)))
+
+            if len(_masks) == 0:
                 masks.append(
                     tv_tensors.Mask(
-                        torch.stack([torch.as_tensor(m) for m in predicted_mask.mask], dim=0),
-                        dtype=torch.float32,
-                        device=self.device,
+                        torch.zeros((1, *inputs.imgs_info[idx].ori_shape), dtype=torch.float32, device=self.device),
                     ),
                 )
                 prompts.append(
-                    Points(
-                        torch.stack([torch.as_tensor(p[:2]) for p in predicted_mask.points], dim=0),
-                        canvas_size=inputs.imgs_info[0].ori_shape,
-                        dtype=torch.float32,
-                        device=self.device,
-                    ),
-                )
-                scores.append(
-                    torch.stack([torch.as_tensor(p[2]) for p in predicted_mask.points], dim=0).to(self.device),
-                )
-                labels.append(
-                    torch.cat([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device),
+                    Points([], canvas_size=inputs.imgs_info[idx].ori_shape, dtype=torch.float32, device=self.device),
                 )
+                scores.append(torch.tensor([-1.0], dtype=torch.float32, device=self.device))
+                labels.append(torch.tensor([-1], dtype=torch.int64, device=self.device))
+                continue
+
+            masks.append(tv_tensors.Mask(np.concatenate(_masks, axis=0), dtype=torch.float32, device=self.device))
+            prompts.append(
+                Points(
+                    np.concatenate(_prompts, axis=0),
+                    canvas_size=inputs.imgs_info[idx].ori_shape,
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            )
+            scores.append(torch.as_tensor(np.concatenate(_scores, axis=0), dtype=torch.float32, device=self.device))
+            labels.append(torch.as_tensor(np.concatenate(_labels, axis=0), dtype=torch.int64, device=self.device))
 
         return ZeroShotVisualPromptingBatchPredEntity(
             batch_size=len(outputs),
@@ -1193,15 +1191,19 @@ def _load_and_assign_reference_info(path: Path) -> bool:
             # if `path_to_directly_load` is given, forcely load
             return _load_and_assign_reference_info(path_to_directly_load)
 
-        _infer_reference_info_root: Path = (
-            self.infer_reference_info_root
-            if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
-            else Path(default_root_dir) / self.infer_reference_info_root
-        )
+        if str(self.infer_reference_info_root) == "../.latest/train":
+            # for default setting
+            path_reference_info = (
+                Path(default_root_dir)
+                / self.infer_reference_info_root
+                / self.reference_info_dir
+                / "reference_info.pickle"
+            )
+        else:
+            # for user input
+            path_reference_info = self.infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
 
-        if (
-            path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
-        ).is_file():
+        if path_reference_info.is_file():
             return _load_and_assign_reference_info(path_reference_info)
 
         return False
@@ -1222,6 +1224,7 @@ def on_test_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()  # to use infer logic
             self.training = False
 
@@ -1238,6 +1241,7 @@ def on_predict_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()  # to use infer logic
             self.training = False
 
diff --git a/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml b/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml
index 67b8b60cbeb..777dde2a345 100644
--- a/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml
+++ b/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml
@@ -7,7 +7,7 @@ model:
     async_inference: False
     use_throughput_mode: True
     reference_info_dir: reference_infos
-    infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+    infer_reference_info_root: ../.latest/train
     save_outputs: True
 
 engine:
diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml
index 187a3b15ea8..d93ed540a7e 100644
--- a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml
+++ b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml
@@ -10,7 +10,7 @@ model:
     default_threshold_target: 0.65
     save_outputs: True
     reference_info_dir: reference_infos
-    infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+    infer_reference_info_root: ../.latest/train
     # options
     use_stability_score: False
     return_single_mask: False
diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml
index b7ca30b8708..64a84003374 100644
--- a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml
+++ b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml
@@ -10,7 +10,7 @@ model:
     default_threshold_target: 0.65
     save_outputs: True
     reference_info_dir: reference_infos
-    infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+    infer_reference_info_root: ../.latest/train
     # options
     use_stability_score: False
     return_single_mask: False
diff --git a/tests/integration/api/test_auto_configuration.py b/tests/integration/api/test_auto_configuration.py
index dc9d87f8ad5..72f01547b2c 100644
--- a/tests/integration/api/test_auto_configuration.py
+++ b/tests/integration/api/test_auto_configuration.py
@@ -44,7 +44,7 @@ def test_auto_configuration(
         device=fxt_accelerator,
     )
     if task.lower() == "zero_shot_visual_prompting":
-        engine.model.infer_reference_info_root = Path()
+        engine.model.infer_reference_info_root = Path(tmp_path_train)
         # update litmodule.hparams to reflect changed hparams
         engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
 
diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py
index d6ff3f89932..ae303b079fb 100644
--- a/tests/integration/api/test_engine_api.py
+++ b/tests/integration/api/test_engine_api.py
@@ -48,7 +48,7 @@ def test_engine_from_config(
         device=fxt_accelerator,
     )
     if task.lower() == "zero_shot_visual_prompting":
-        engine.model.infer_reference_info_root = Path()
+        engine.model.infer_reference_info_root = Path(tmp_path_train)
         # update litmodule.hparams to reflect changed hparams
         engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
 
@@ -97,7 +97,7 @@ def test_engine_from_config(
         model_name=str(exported_model_path["decoder"]),
         label_info=engine.datamodule.label_info,
     )
-    engine.model.infer_reference_info_root = Path()
+    engine.model.infer_reference_info_root = Path(tmp_path_train)
     # update litmodule.hparams to reflect changed hparams
     engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
     test_metric_from_ov_model = engine.test(checkpoint=exported_model_path["decoder"], accelerator="cpu")
diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py
index 50ad5f3eb4c..93f169cfdd8 100644
--- a/tests/unit/core/model/test_visual_prompting.py
+++ b/tests/unit/core/model/test_visual_prompting.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+from copy import deepcopy
 from pathlib import Path
 from unittest.mock import Mock
 
@@ -12,6 +13,7 @@
 import pytest
 import torch
 from model_api.models import SAMLearnableVisualPrompter, SAMVisualPrompter
+from model_api.models.utils import PredictedMask
 from otx.core.data.entity.base import Points
 from otx.core.data.entity.visual_prompting import (
     VisualPromptingBatchPredEntity,
@@ -500,32 +502,58 @@ def test_customize_outputs_training(
     ) -> None:
         ov_zero_shot_visual_prompting_model.training = True
 
-        outputs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
+        outputs = ({"foo": np.array(1), "bar": np.array(2)}, [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])])
 
         result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1])
 
         assert result == outputs
 
+    @pytest.mark.parametrize(
+        "outputs",
+        [
+            [
+                {
+                    1: PredictedMask(mask=[[1, 2, 3], [4, 5, 6]], points=[[13, 14, 15], [16, 17, 18]]),
+                    2: PredictedMask(mask=[[7, 8, 9], [10, 11, 12]], points=[[19, 20, 21], [22, 23, 24]]),
+                },
+            ],
+            [
+                {
+                    1: PredictedMask(mask=[], points=[]),
+                },
+            ],
+            [
+                {
+                    1: PredictedMask(mask=[[1, 2, 3], [4, 5, 6]], points=[[13, 14, 15], [16, 17, 18]]),
+                    2: PredictedMask(mask=[[7, 8, 9], [10, 11, 12]], points=[[19, 20, 21], [22, 23, 24]]),
+                },
+                {
+                    1: PredictedMask(mask=[[1, 2, 3], [4, 5, 6]], points=[[13, 14, 15], [16, 17, 18]]),
+                    2: PredictedMask(mask=[[7, 8, 9], [10, 11, 12]], points=[[19, 20, 21], [22, 23, 24]]),
+                },
+            ],
+        ],
+    )
     def test_customize_outputs_inference(
         self,
         ov_zero_shot_visual_prompting_model,
         fxt_zero_shot_vpm_data_entity,
+        outputs: list[dict[int, PredictedMask]],
     ) -> None:
         ov_zero_shot_visual_prompting_model.training = False
+        entity = deepcopy(fxt_zero_shot_vpm_data_entity[1])
+        if len(outputs) > 1:
+            # for multi batch testing
+            entity.batch_size = 2
+            entity.images = [entity.images[0], entity.images[0]]
+            entity.imgs_info = [entity.imgs_info[0], entity.imgs_info[0]]
 
-        from model_api.models.utils import PredictedMask
-
-        outputs = [
-            {1: PredictedMask([], [4, 5, 6])},
-            {2: PredictedMask([], [16, 17, 18])},
-        ]
-
-        result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1])
+        result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, entity)
 
         assert isinstance(result, ZeroShotVisualPromptingBatchPredEntity)
         assert result.batch_size == len(outputs)
+        assert result.images == entity.images
+        assert result.imgs_info == entity.imgs_info
         assert isinstance(result.masks, list)
         assert all(isinstance(mask, tv_tensors.Mask) for mask in result.masks)
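
Reviewer note (not part of the diff): the sketch below restates, in plain Python, the two behaviours introduced above, for reference only. The names group_image_output and resolve_reference_info are hypothetical, and PredictedMask is stubbed with a dataclass instead of being imported from model_api.models.utils.

# Illustrative sketch of the per-image grouping performed by the reworked _customize_outputs:
# each image's {label: PredictedMask} dict is collapsed into flat arrays, and an image with no
# predictions gets a single placeholder entry (score -1.0, label -1) so the batch stays aligned
# for downstream metrics.
from dataclasses import dataclass
from pathlib import Path

import numpy as np


@dataclass
class PredictedMask:  # stand-in for model_api.models.utils.PredictedMask
    mask: list    # list of HxW arrays, one per accepted prompt
    points: list  # list of (x, y, score) triples, one per accepted prompt


def group_image_output(output: dict, ori_shape: tuple):
    masks, points, scores, labels = [], [], [], []
    for label, pred in output.items():
        if len(pred.mask) == 0:
            continue
        masks.append(np.stack(pred.mask, axis=0))
        pts = np.stack(pred.points, axis=0)  # (N, 3): x, y, score
        points.append(pts[:, :2])
        scores.append(pts[:, 2])
        labels.append(np.full(len(pts), label, dtype=np.int64))

    if not masks:
        # empty prediction: dummy mask, no prompts, score -1.0, label -1
        return (
            np.zeros((1, *ori_shape), dtype=np.float32),
            np.zeros((0, 2), dtype=np.float32),
            np.array([-1.0], dtype=np.float32),
            np.array([-1], dtype=np.int64),
        )

    return (
        np.concatenate(masks, axis=0),
        np.concatenate(points, axis=0),
        np.concatenate(scores, axis=0),
        np.concatenate(labels, axis=0),
    )


# Sketch of the reference-info path resolution used by both load_reference_info implementations:
# only the literal recipe default "../.latest/train" is re-rooted under the trainer's
# default_root_dir; any user-provided path is used as given.
def resolve_reference_info(infer_root: Path, default_root_dir: Path, reference_info_dir: str, filename: str) -> Path:
    if str(infer_root) == "../.latest/train":
        # default recipe value: resolve relative to default_root_dir
        return Path(default_root_dir) / infer_root / reference_info_dir / filename
    # user input: use as given
    return Path(infer_root) / reference_info_dir / filename


if __name__ == "__main__":
    out = {1: PredictedMask(mask=[np.ones((4, 4))], points=[(2.0, 3.0, 0.9)])}
    m, p, s, l = group_image_output(out, ori_shape=(4, 4))
    print(m.shape, p.shape, s, l)  # (1, 4, 4) (1, 2) [0.9] [1]
    print(resolve_reference_info(Path("../.latest/train"), Path("otx-workspace"), "reference_infos", "reference_info.pt"))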