Skip to content

Commit

Permalink
Fix semantic segmentation soft prediction dtype (#2322)
Browse files Browse the repository at this point in the history
* Fix semantic segmentation soft prediction dtype

* relax ref sal vals check

---------

Co-authored-by: Songki Choi <[email protected]>
  • Loading branch information
negvet and goodsong81 authored Jul 10, 2023
1 parent 344f526 commit cfd7706
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/otx/algorithms/segmentation/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def add_prediction(
current_label_soft_prediction = soft_prediction[:, :, label_index]
if process_soft_prediction:
current_label_soft_prediction = get_activation_map(current_label_soft_prediction)
else:
current_label_soft_prediction = (current_label_soft_prediction * 255).astype(np.uint8)
result_media = ResultMediaEntity(
name=label.name,
type="soft_prediction",
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algorithms/segmentation/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, dump_soft_pre
current_label_soft_prediction = soft_prediction[:, :, label_index]
if process_soft_prediction:
current_label_soft_prediction = get_activation_map(current_label_soft_prediction)
else:
current_label_soft_prediction = (current_label_soft_prediction * 255).astype(np.uint8)
result_media = ResultMediaEntity(
name=label.name,
type="soft_prediction",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def test_saliency_map_cls(self, template):
assert len(saliency_maps) == 2
assert saliency_maps[0].ndim == 3
assert saliency_maps[0].shape == (1000, 7, 7)
assert (saliency_maps[0][0][0] == self.ref_saliency_vals_cls[template.name]).all()
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_cls[template.name]) <= 1)
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_saliency_map_det(self, template):
assert len(saliency_maps) == 2
assert saliency_maps[0].ndim == 3
assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name]
assert (saliency_maps[0][0][0] == self.ref_saliency_vals_det[template.name]).all()
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_det[template.name]) <= 1)

@e2e_pytest_unit
@pytest.mark.parametrize("template", templates_det, ids=templates_det_ids)
Expand Down

0 comments on commit cfd7706

Please sign in to comment.