Skip to content

Commit

Permalink
Fix tiling XAI out of range (#3943)
Browse files Browse the repository at this point in the history
- Fix tile merge XAI out of range
  • Loading branch information
eugene123tw authored Sep 9, 2024
1 parent 1d319cd commit aaa2765
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> D
inputs.imgs_info,
self.num_classes,
self.tile_config,
self.explain_mode,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntit
inputs.imgs_info,
self.num_classes,
self.tile_config,
self.explain_mode,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
Expand Down
12 changes: 7 additions & 5 deletions src/otx/core/utils/tile_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,25 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
Args:
img_infos (list[ImageInfo]): Original image information before tiling.
iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.45.
max_num_instances (int, optional): Maximum number of instances to keep. Defaults to 500.
num_classes (int): Number of classes.
tile_config (TileConfig): Tile configuration.
explain_mode (bool): Whether or not tiles have explain features. Default: False.
"""

def __init__(
self,
img_infos: list[ImageInfo],
num_classes: int,
tile_config: TileConfig,
explain_mode: bool = False,
) -> None:
self.img_infos = img_infos
self.num_classes = num_classes
self.tile_size = tile_config.tile_size
self.iou_threshold = tile_config.iou_threshold
self.max_num_instances = tile_config.max_num_instances
self.with_full_img = tile_config.with_full_img
self.explain_mode = explain_mode

@abstractmethod
def _merge_entities(
Expand Down Expand Up @@ -115,7 +117,7 @@ def merge(
"""
entities_to_merge = defaultdict(list)
img_ids = []
explain_mode = len(batch_tile_preds[0].feature_vector) > 0
explain_mode = self.explain_mode

for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
batch_size = tile_preds.batch_size
Expand Down Expand Up @@ -315,7 +317,7 @@ def merge(
"""
entities_to_merge = defaultdict(list)
img_ids = []
explain_mode = len(batch_tile_preds[0].feature_vector) > 0
explain_mode = self.explain_mode

for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]
Expand Down

0 comments on commit aaa2765

Please sign in to comment.