
Commit

Merge branch 'develop' into add-merge-group-wf
yunchu authored Jul 5, 2024
2 parents 98f980f + f9808c5 commit 2ed6933
Showing 8 changed files with 151 additions and 98 deletions.
59 changes: 40 additions & 19 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -16,7 +16,7 @@
 import torch
 import torchvision.transforms.v2 as tvt_v2
 from datumaro import Polygon as dmPolygon
-from torch import LongTensor, Tensor, nn
+from torch import Tensor, nn
 from torch.nn import functional as F  # noqa: N812
 from torchvision import tv_tensors
 from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor
@@ -764,21 +764,38 @@ def _customize_outputs(  # type: ignore[override]
         masks: list[Mask] = []
         prompts: list[Points] = []
         scores: list[Tensor] = []
-        labels: list[LongTensor] = []
-        for predicted_masks, used_points in outputs:
+        labels: list[Tensor] = []
+        for idx, (predicted_masks, used_points) in enumerate(outputs):
+            _masks: list[Tensor] = []
+            _prompts: list[Tensor] = []
+            _scores: list[Tensor] = []
+            _labels: list[Tensor] = []
             for label, predicted_mask in predicted_masks.items():
                 if len(predicted_mask) == 0:
                     continue
-                masks.append(Mask(torch.stack(predicted_mask, dim=0), dtype=torch.float32))
-                prompts.append(
-                    Points(
-                        torch.stack([p[:2] for p in used_points[label]], dim=0),
-                        canvas_size=inputs.imgs_info[0].ori_shape,
-                        dtype=torch.float32,
-                    ),
-                )
-                scores.append(torch.stack([p[2] for p in used_points[label]], dim=0))
-                labels.append(torch.cat([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0))
+                _masks.append(torch.stack(predicted_mask, dim=0))
+                _used_points_scores = torch.stack(used_points[label], dim=0)
+                _prompts.append(_used_points_scores[:, :2])
+                _scores.append(_used_points_scores[:, 2])
+                _labels.append(torch.tensor([label] * len(_used_points_scores), dtype=torch.int64, device=self.device))
+
+            if len(_masks) == 0:
+                masks.append(
+                    tv_tensors.Mask(
+                        torch.zeros((1, *inputs.imgs_info[idx].ori_shape), dtype=torch.float32, device=self.device),
+                    ),
+                )
+                prompts.append(
+                    Points([], canvas_size=inputs.imgs_info[idx].ori_shape, dtype=torch.float32, device=self.device),
+                )
+                scores.append(torch.tensor([-1.0], dtype=torch.float32, device=self.device))
+                labels.append(torch.tensor([-1], dtype=torch.int64, device=self.device))
+                continue
+
+            masks.append(tv_tensors.Mask(torch.cat(_masks, dim=0)))
+            prompts.append(Points(torch.cat(_prompts, dim=0), canvas_size=inputs.imgs_info[idx].ori_shape))
+            scores.append(torch.cat(_scores, dim=0))
+            labels.append(torch.cat(_labels, dim=0))
 
         return ZeroShotVisualPromptingBatchPredEntity(
             batch_size=len(outputs),
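This hunk changes the aggregation so that each image contributes exactly one masks/prompts/scores/labels entry, with a placeholder emitted when nothing is predicted. A minimal sketch of that placeholder convention (hypothetical helper and toy shape, not part of the commit):

import torch

def placeholder_entry(ori_shape: tuple[int, int], device: str = "cpu") -> dict[str, torch.Tensor]:
    # Mirrors the empty-prediction fallback above: one all-zero mask plus
    # sentinel score -1.0 and sentinel label -1 keep the batch aligned.
    return {
        "masks": torch.zeros((1, *ori_shape), dtype=torch.float32, device=device),
        "scores": torch.tensor([-1.0], dtype=torch.float32, device=device),
        "labels": torch.tensor([-1], dtype=torch.int64, device=device),
    }

entry = placeholder_entry((480, 640))
assert entry["masks"].shape == (1, 480, 640)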
@@ -926,15 +943,19 @@ def load_reference_info(
             log.info(f"reference info saved at {path_to_directly_load} was successfully loaded.")
 
         else:
-            _infer_reference_info_root: Path = (
-                self.infer_reference_info_root
-                if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
-                else Path(default_root_dir) / self.infer_reference_info_root
-            )
+            if str(self.infer_reference_info_root) == "../.latest/train":
+                # for default setting
+                path_reference_info = (
+                    Path(default_root_dir)
+                    / self.infer_reference_info_root
+                    / self.reference_info_dir
+                    / "reference_info.pt"
+                )
+            else:
+                # for user input
+                path_reference_info = self.infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
 
-            if (
-                path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
-            ).is_file():
+            if path_reference_info.is_file():
                 reference_info = torch.load(path_reference_info)
                 retval = True
                 log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
128 changes: 66 additions & 62 deletions src/otx/core/model/visual_prompting.py
@@ -43,6 +43,7 @@
 
 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+    from model_api.models.utils import PredictedMask, VisualPromptingResult
     from torchmetrics import MetricCollection
 
     from otx.core.data.module import OTXDataModule
@@ -139,23 +140,6 @@ def _inference_step_for_zero_shot(
                     for ett in converted_entities["preds"]
                 ]
                 _target = converted_entities["target"]
-
-                # match #_preds and #_target
-                if len(_preds) > len(_target):
-                    # interpolate _target
-                    num_diff = len(_preds) - len(_target)
-                    for idx in range(num_diff):
-                        _target.append(_target[idx])
-                elif len(_preds) < len(_target):
-                    num_diff = len(_target) - len(_preds)
-                    pad_prediction = {
-                        "masks": torch.zeros_like(_target[0]["masks"], dtype=_target[0]["masks"].dtype),
-                        "labels": torch.zeros_like(_target[0]["labels"], dtype=_target[0]["labels"].dtype),
-                        "scores": torch.zeros(len(_target[0]["labels"]), dtype=torch.float32),
-                    }  # for empty prediction
-                    for idx in range(num_diff):
-                        _preds.append(_preds[idx] if idx < len(_preds) else pad_prediction)
                 _metric.update(preds=_preds, target=_target)
             elif _name in ["iou", "f1-score", "dice"]:
                 # BinaryJaccardIndex, BinaryF1Score, Dice
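Since `_customize_outputs` now emits exactly one entry per image (a zero-mask placeholder with score -1 and label -1 when nothing is predicted), `_preds` and `_target` always have equal lengths, so the length-matching workaround above could be removed.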
@@ -345,6 +329,7 @@ def on_test_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir, self.device):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()
             # to use infer logic
             self.training = False
@@ -362,6 +347,7 @@ def on_predict_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir, self.device):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()
             # to use infer logic
             self.training = False
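Both hooks now set `self.training = True` before running the fit loop, so the automatic `Learn` pass takes the training branch, and reset it afterwards for inference. A minimal sketch (hypothetical class, not the real model) of the dispatch this flag drives:

class DispatchSketch:
    def __init__(self) -> None:
        self.training = False

    def learn(self) -> str:
        return "collect reference features"

    def infer(self) -> str:
        return "predict with stored references"

    def forward(self) -> str:
        # Same pattern as `fn = self.learn if self.training else self.infer` below.
        fn = self.learn if self.training else self.infer
        return fn()

model = DispatchSketch()
model.training = True
assert model.forward() == "collect reference features"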
@@ -530,7 +516,7 @@ def forward(
         )
 
         images, batch_prompts = self._customize_inputs(inputs)
-        outputs: list[Any] = []
+        outputs: list[VisualPromptingResult] = []
         for image, prompt in zip(images, batch_prompts):
             outputs.append(self.model(image, **prompt))
 
@@ -576,31 +562,28 @@ def _customize_inputs(  # type: ignore[override]
 
     def _customize_outputs(
         self,
-        outputs: Any,  # noqa: ANN401
+        outputs: list[VisualPromptingResult],
         inputs: VisualPromptingBatchDataEntity,  # type: ignore[override]
     ) -> VisualPromptingBatchPredEntity:
         """Customize OTX output batch data entity if needed for model."""
         masks: list[tv_tensors.Mask] = []
-        scores: list[torch.Tensor] = []
+        scores: list[Tensor] = []
+        labels: list[Tensor] = []
         for image_output in outputs:
-            masks.extend(
-                [
-                    torch.as_tensor(hard_prediction, device=self.device)
-                    for hard_prediction in image_output.hard_predictions
-                ],
-            )
-            scores.extend([torch.as_tensor(score, device=self.device) for score in image_output.scores])
+            masks.append(tv_tensors.Mask(np.concatenate(image_output.hard_predictions), device=self.device))
+            scores.append(torch.as_tensor(np.concatenate(image_output.scores), device=self.device))
+            labels.append(torch.as_tensor(image_output.labels, device=self.device))
 
         return VisualPromptingBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,
             imgs_info=inputs.imgs_info,
-            scores=[torch.cat(scores, dim=0)],
-            masks=[tv_tensors.Mask(torch.cat(masks, dim=0))],
+            scores=scores,
+            masks=masks,
             polygons=[],
             points=[],
             bboxes=[],
-            labels=[torch.cat(list(labels.values())) for labels in inputs.labels],
+            labels=labels,
         )
 
     def optimize(  # type: ignore[override]
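The OpenVINO post-processing now keeps one mask/score/label tensor per image instead of flattening the whole batch, which preserves image boundaries for the metric step. A toy illustration (made-up shapes, not the commit's data):

import numpy as np
import torch

# Two images with 2 and 3 predicted masks respectively.
hard_predictions = [np.zeros((2, 4, 4), dtype=np.uint8), np.zeros((3, 4, 4), dtype=np.uint8)]

# Old behaviour: one flattened tensor for the whole batch.
flattened = torch.cat([torch.as_tensor(h) for h in hard_predictions], dim=0)  # shape (5, 4, 4)

# New behaviour: one tensor per image, so image boundaries survive.
per_image = [torch.as_tensor(h) for h in hard_predictions]  # shapes (2, 4, 4) and (3, 4, 4)
assert flattened.shape[0] == sum(t.shape[0] for t in per_image)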
@@ -884,11 +867,11 @@ def infer(
         inputs: ZeroShotVisualPromptingBatchDataEntity,
         reference_feats: np.ndarray,
         used_indices: np.ndarray,
-    ) -> list[list[defaultdict[int, list]]]:
+    ) -> list[dict[int, PredictedMask]]:
         """`Infer` for target predictions."""
         images, _ = self._customize_inputs(inputs)
 
-        total_results: list[list[defaultdict[int, list]]] = []
+        total_results: list[dict[int, PredictedMask]] = []
         for image in images:
             result = self.model(image, VisualPromptingFeatures(reference_feats, used_indices))
             total_results.append(result.data)
@@ -898,7 +881,7 @@ def infer(
     def forward(  # type: ignore[override]
         self,
         inputs: ZeroShotVisualPromptingBatchDataEntity,  # type: ignore[override]
-    ) -> ZeroShotVisualPromptingBatchPredEntity:
+    ) -> tuple[dict[str, np.ndarray], list[np.ndarray]] | ZeroShotVisualPromptingBatchPredEntity:
         """Model forward function."""
         kwargs: dict[str, Any] = {}
         fn = self.learn if self.training else self.infer
@@ -959,42 +942,57 @@ def _customize_inputs(  # type: ignore[override]
 
     def _customize_outputs(  # type: ignore[override]
         self,
-        outputs: Any,  # noqa: ANN401
+        outputs: tuple[dict[str, np.ndarray], list[np.ndarray]] | list[dict[int, PredictedMask]],
         inputs: ZeroShotVisualPromptingBatchDataEntity,  # type: ignore[override]
-    ) -> ZeroShotVisualPromptingBatchPredEntity:
+    ) -> tuple[dict[str, np.ndarray], list[np.ndarray]] | ZeroShotVisualPromptingBatchPredEntity:
         """Customize OTX output batch data entity if needed for model."""
-        if self.training:
+        if self.training and isinstance(outputs, tuple):
             return outputs
 
         masks: list[tv_tensors.Mask] = []
         prompts: list[Points] = []
-        scores: list[torch.Tensor] = []
-        labels: list[torch.LongTensor] = []
-        for output in outputs:
+        scores: list[Tensor] = []
+        labels: list[Tensor] = []
+        for idx, output in enumerate(outputs):
+            if not isinstance(output, dict):
+                continue
+            _masks: list[np.ndarray] = []
+            _prompts: list[np.ndarray] = []
+            _scores: list[np.ndarray] = []
+            _labels: list[np.ndarray] = []
             for label, predicted_mask in output.items():
                 if len(predicted_mask.mask) == 0:
                     continue
-                masks.append(
-                    tv_tensors.Mask(
-                        torch.stack([torch.as_tensor(m) for m in predicted_mask.mask], dim=0),
-                        dtype=torch.float32,
-                        device=self.device,
-                    ),
-                )
-                prompts.append(
-                    Points(
-                        torch.stack([torch.as_tensor(p[:2]) for p in predicted_mask.points], dim=0),
-                        canvas_size=inputs.imgs_info[0].ori_shape,
-                        dtype=torch.float32,
-                        device=self.device,
-                    ),
-                )
-                scores.append(
-                    torch.stack([torch.as_tensor(p[2]) for p in predicted_mask.points], dim=0).to(self.device),
-                )
-                labels.append(
-                    torch.cat([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device),
-                )
+                _masks.append(np.stack(predicted_mask.mask, axis=0))
+                _used_points_scores = np.stack(predicted_mask.points, axis=0)
+                _prompts.append(_used_points_scores[:, :2])
+                _scores.append(_used_points_scores[:, 2])
+                _labels.append(np.array([label] * len(_used_points_scores)))
+
+            if len(_masks) == 0:
+                masks.append(
+                    tv_tensors.Mask(
+                        torch.zeros((1, *inputs.imgs_info[idx].ori_shape), dtype=torch.float32, device=self.device),
+                    ),
+                )
+                prompts.append(
+                    Points([], canvas_size=inputs.imgs_info[idx].ori_shape, dtype=torch.float32, device=self.device),
+                )
+                scores.append(torch.tensor([-1.0], dtype=torch.float32, device=self.device))
+                labels.append(torch.tensor([-1], dtype=torch.int64, device=self.device))
+                continue
+
+            masks.append(tv_tensors.Mask(np.concatenate(_masks, axis=0), dtype=torch.float32, device=self.device))
+            prompts.append(
+                Points(
+                    np.concatenate(_prompts, axis=0),
+                    canvas_size=inputs.imgs_info[idx].ori_shape,
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            )
+            scores.append(torch.as_tensor(np.concatenate(_scores, axis=0), dtype=torch.float32, device=self.device))
+            labels.append(torch.as_tensor(np.concatenate(_labels, axis=0), dtype=torch.int64, device=self.device))
 
         return ZeroShotVisualPromptingBatchPredEntity(
             batch_size=len(outputs),
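As in the torch model, used points come back as rows of (x, y, score), so the columns split into prompt coordinates and confidences. A small sketch with made-up values:

import numpy as np

used_points_scores = np.array([[10.0, 20.0, 0.9], [30.0, 40.0, 0.8]])
prompts = used_points_scores[:, :2]  # (x, y) point coordinates
scores = used_points_scores[:, 2]    # confidence per point
assert prompts.shape == (2, 2) and scores.shape == (2,)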
Expand Down Expand Up @@ -1193,15 +1191,19 @@ def _load_and_assign_reference_info(path: Path) -> bool:
# if `path_to_directly_load` is given, forcely load
return _load_and_assign_reference_info(path_to_directly_load)

_infer_reference_info_root: Path = (
self.infer_reference_info_root
if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
else Path(default_root_dir) / self.infer_reference_info_root
)
if str(self.infer_reference_info_root) == "../.latest/train":
# for default setting
path_reference_info = (
Path(default_root_dir)
/ self.infer_reference_info_root
/ self.reference_info_dir
/ "reference_info.pickle"
)
else:
# for user input
path_reference_info = self.infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"

if (
path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
).is_file():
if path_reference_info.is_file():
return _load_and_assign_reference_info(path_reference_info)

return False
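The OpenVINO model applies the same lookup rule as the torch model above, loading `reference_info.pickle` in place of `reference_info.pt`.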
@@ -1222,6 +1224,7 @@ def on_test_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()
             # to use infer logic
             self.training = False
@@ -1238,6 +1241,7 @@ def on_predict_start(self) -> None:
         if not self.load_reference_info(self.trainer.default_root_dir):
             log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
+            self.training = True
             self.trainer.fit_loop.run()
             # to use infer logic
             self.training = False
@@ -7,7 +7,7 @@ model:
   async_inference: False
   use_throughput_mode: True
   reference_info_dir: reference_infos
-  infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+  infer_reference_info_root: ../.latest/train
   save_outputs: True
 
 engine:
@@ -10,7 +10,7 @@ model:
   default_threshold_target: 0.65
   save_outputs: True
   reference_info_dir: reference_infos
-  infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+  infer_reference_info_root: ../.latest/train
   # options
   use_stability_score: False
   return_single_mask: False
2 changes: 1 addition & 1 deletion src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml
@@ -10,7 +10,7 @@ model:
   default_threshold_target: 0.65
   save_outputs: True
   reference_info_dir: reference_infos
-  infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location
+  infer_reference_info_root: ../.latest/train
   # options
   use_stability_score: False
   return_single_mask: False
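All three recipes drop the inline comment because, under the lookup rule above, only the literal default `../.latest/train` is resolved against the trainer's output directory; any other value, absolute or relative, is now used verbatim, so the "set absolute path" guidance no longer applies.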
2 changes: 1 addition & 1 deletion tests/integration/api/test_auto_configuration.py
@@ -44,7 +44,7 @@ def test_auto_configuration(
         device=fxt_accelerator,
     )
     if task.lower() == "zero_shot_visual_prompting":
-        engine.model.infer_reference_info_root = Path()
+        engine.model.infer_reference_info_root = Path(tmp_path_train)
        # update litmodule.hparams to reflect changed hparams
        engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
 
4 changes: 2 additions & 2 deletions tests/integration/api/test_engine_api.py
@@ -48,7 +48,7 @@ def test_engine_from_config(
         device=fxt_accelerator,
     )
     if task.lower() == "zero_shot_visual_prompting":
-        engine.model.infer_reference_info_root = Path()
+        engine.model.infer_reference_info_root = Path(tmp_path_train)
        # update litmodule.hparams to reflect changed hparams
        engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
 
@@ -97,7 +97,7 @@ def test_engine_from_config(
         model_name=str(exported_model_path["decoder"]),
         label_info=engine.datamodule.label_info,
     )
-    engine.model.infer_reference_info_root = Path()
+    engine.model.infer_reference_info_root = Path(tmp_path_train)
     # update litmodule.hparams to reflect changed hparams
     engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)})
     test_metric_from_ov_model = engine.test(checkpoint=exported_model_path["decoder"], accelerator="cpu")
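The tests previously relied on the old resolution rule, under which the non-absolute `Path()` was resolved against the trainer's `default_root_dir`; with the new rule a non-default value is taken verbatim, so the tests now point `infer_reference_info_root` at the training output directory (`tmp_path_train`) explicitly.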
