From b3aa3a04dba3cb6e236d9421c78c0d9361a91577 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Fri, 23 Jun 2023 14:04:23 -0700 Subject: [PATCH 1/2] fix getIOU bug --- .../ai/djl/modality/cv/output/Rectangle.java | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java index caad442f287..a3f47000328 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java @@ -86,15 +86,30 @@ public Point getPoint() { /** {@inheritDoc} */ @Override public double getIoU(BoundingBox box) { - Rectangle rec = (Rectangle) box; - // caculate intesection lrtb - double left = Math.max(getX(), rec.getX()); - double top = Math.max(getY(), rec.getY()); - double right = Math.min(getX() + getWidth(), rec.getX() + rec.getWidth()); - double bottom = Math.min(getY() + getHeight(), rec.getY() + rec.getHeight()); - double intersection = (right - left) * (bottom - top); - return intersection - / (getWidth() * getHeight() + rec.getWidth() * rec.getHeight() - intersection); + return getIoU(this, (Rectangle) box); + } + + public double getIoU(Rectangle rec1, Rectangle rec2) { + // computing area of each rectangles + double s1 = rec1.getWidth() * rec1.getHeight(); + double s2 = rec2.getWidth() * rec2.getHeight(); + + // computing the sum_area + double sumArea = s1 + s2; + + // find the each edge of intersect rectangle + double left = Math.max(rec1.getX(), rec2.getX()); + double top = Math.max(rec1.getY(), rec2.getY()); + double right = Math.min(rec1.getX() + rec1.getWidth(), rec2.getX() + rec2.getWidth()); + double bottom = Math.min(rec1.getY() + rec1.getHeight(), rec2.getY() + rec2.getHeight()); + + // judge if there is an intersect + if (left >= right || top >= bottom) { + return 0.0; + } else { + double intersect = (right - left) * (bottom - top); + return intersect / (sumArea - intersect); + } } /** From 2ad45e53d7b3badf3624cf14d6ccd8b40b3a9c35 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 26 Jun 2023 21:10:06 -0700 Subject: [PATCH 2/2] refactor Rectangle --- .../djl/modality/cv/output/BoundingBox.java | 2 +- .../ai/djl/modality/cv/output/Rectangle.java | 30 ++++++++---------- .../djl/modality/cv/output/RectangleTest.java | 31 +++++++++++++++++++ .../djl/modality/cv/output/package-info.java | 14 +++++++++ 4 files changed, 59 insertions(+), 18 deletions(-) create mode 100644 api/src/test/java/ai/djl/modality/cv/output/RectangleTest.java create mode 100644 api/src/test/java/ai/djl/modality/cv/output/package-info.java diff --git a/api/src/main/java/ai/djl/modality/cv/output/BoundingBox.java b/api/src/main/java/ai/djl/modality/cv/output/BoundingBox.java index e07627e7556..3e582f46313 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/BoundingBox.java +++ b/api/src/main/java/ai/djl/modality/cv/output/BoundingBox.java @@ -41,7 +41,7 @@ public interface BoundingBox extends Serializable { Point getPoint(); /** - * Gets the Intersection over Union (IoU) value between bounding boxes. + * Returns the Intersection over Union (IoU) value between bounding boxes. * *

Also known as Jaccard index * diff --git a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java index a3f47000328..92afc603272 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java @@ -86,30 +86,26 @@ public Point getPoint() { /** {@inheritDoc} */ @Override public double getIoU(BoundingBox box) { - return getIoU(this, (Rectangle) box); - } + Rectangle rect = box.getBounds(); - public double getIoU(Rectangle rec1, Rectangle rec2) { // computing area of each rectangles - double s1 = rec1.getWidth() * rec1.getHeight(); - double s2 = rec2.getWidth() * rec2.getHeight(); - - // computing the sum_area + double s1 = (width + 1) * (height + 1); + double s2 = (rect.getWidth() + 1) * (rect.getHeight() + 1); double sumArea = s1 + s2; - // find the each edge of intersect rectangle - double left = Math.max(rec1.getX(), rec2.getX()); - double top = Math.max(rec1.getY(), rec2.getY()); - double right = Math.min(rec1.getX() + rec1.getWidth(), rec2.getX() + rec2.getWidth()); - double bottom = Math.min(rec1.getY() + rec1.getHeight(), rec2.getY() + rec2.getHeight()); + // find each edge of intersect rectangle + double left = Math.max(getX(), rect.getX()); + double top = Math.max(getY(), rect.getY()); + double right = Math.min(getX() + getWidth(), rect.getX() + rect.getWidth()); + double bottom = Math.min(getY() + getHeight(), rect.getY() + rect.getHeight()); - // judge if there is an intersect - if (left >= right || top >= bottom) { + // judge if there is a intersect + if (left > right || top > bottom) { return 0.0; - } else { - double intersect = (right - left) * (bottom - top); - return intersect / (sumArea - intersect); } + + double intersect = (right - left + 1) * (bottom - top + 1); + return intersect / (sumArea - intersect); } /** diff --git a/api/src/test/java/ai/djl/modality/cv/output/RectangleTest.java b/api/src/test/java/ai/djl/modality/cv/output/RectangleTest.java new file mode 100644 index 00000000000..b4fc3bafaeb --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/output/RectangleTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.output; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RectangleTest { + + @Test + public void testIoU() { + BoundingBox box = new Rectangle(1, 3, 4, 5); + Rectangle rect = new Rectangle(1, 2, 3, 4); + double iou = box.getIoU(rect); + Assert.assertEquals(iou, 0.47058823529411764, 0.00001); + + rect = new Rectangle(6, 2, 3, 4); + iou = box.getIoU(rect); + Assert.assertEquals(iou, 0); + } +} diff --git a/api/src/test/java/ai/djl/modality/cv/output/package-info.java b/api/src/test/java/ai/djl/modality/cv/output/package-info.java new file mode 100644 index 00000000000..e3611e02f6f --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/output/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains tests for {@link ai.djl.modality.cv.output}. */ +package ai.djl.modality.cv.output;