diff --git a/anylabeling/views/labeling/label_converter.py b/anylabeling/views/labeling/label_converter.py index 66f626ef..e6d46ae3 100644 --- a/anylabeling/views/labeling/label_converter.py +++ b/anylabeling/views/labeling/label_converter.py @@ -321,23 +321,35 @@ def voc_to_custom(self, input_file, output_file, image_filename): difficult = "0" if obj.find("difficult") is not None: difficult = str(obj.find("difficult").text) - xmin = float(obj.find("bndbox/xmin").text) - ymin = float(obj.find("bndbox/ymin").text) - xmax = float(obj.find("bndbox/xmax").text) - ymax = float(obj.find("bndbox/ymax").text) - - shape = { - "label": label, - "description": "", - "points": [ + points = [] + if obj.find("polygon") is not None: + num_points = len(obj.find("polygon")) // 2 + for i in range(1, num_points+1): + x_tag = f"polygon/x{i}" + y_tag = f"polygon/y{i}" + x = float(obj.find(x_tag).text) + y = float(obj.find(y_tag).text) + points.append([x, y]) + shape_type = "polygon" + elif obj.find("bndbox") is not None: + xmin = float(obj.find("bndbox/xmin").text) + ymin = float(obj.find("bndbox/ymin").text) + xmax = float(obj.find("bndbox/xmax").text) + ymax = float(obj.find("bndbox/ymax").text) + points = [ [xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax], - ], + ] + shape_type = "rectangle" + shape = { + "label": label, + "description": "", + "points": points, "group_id": None, "difficult": bool(int(difficult)), - "shape_type": "rectangle", + "shape_type": shape_type, "flags": {}, } @@ -616,31 +628,44 @@ def custom_to_voc(self, input_file, output_dir): ET.SubElement(size, "width").text = str(image_width) ET.SubElement(size, "height").text = str(image_height) ET.SubElement(size, "depth").text = "3" - + source = ET.SubElement(root, "source") + ET.SubElement(source, "database").text = "https://github.com/CVHub520/X-AnyLabeling" for shape in data["shapes"]: - if shape["shape_type"] != "rectangle": - continue label = shape["label"] points = shape["points"] difficult = shape.get("difficult", False) - if len(points) == 2: - logger.warning( - "UserWarning: Diagonal vertex mode is deprecated in X-AnyLabeling release v2.2.0 or later.\n" - "Please update your code to accommodate the new four-point mode." - ) - points = rectangle_from_diagonal(points) - xmin, ymin, xmax, ymax = self.calculate_bounding_box(points) - object_elem = ET.SubElement(root, "object") ET.SubElement(object_elem, "name").text = label ET.SubElement(object_elem, "pose").text = "Unspecified" ET.SubElement(object_elem, "truncated").text = "0" + ET.SubElement(object_elem, "occluded").text = "0" ET.SubElement(object_elem, "difficult").text = str(int(difficult)) - bndbox = ET.SubElement(object_elem, "bndbox") - ET.SubElement(bndbox, "xmin").text = str(int(xmin)) - ET.SubElement(bndbox, "ymin").text = str(int(ymin)) - ET.SubElement(bndbox, "xmax").text = str(int(xmax)) - ET.SubElement(bndbox, "ymax").text = str(int(ymax)) + if shape["shape_type"] == "rectangle": + if len(points) == 2: + logger.warning( + "UserWarning: Diagonal vertex mode is deprecated in X-AnyLabeling release v2.2.0 or later.\n" + "Please update your code to accommodate the new four-point mode." + ) + points = rectangle_from_diagonal(points) + xmin, ymin, xmax, ymax = self.calculate_bounding_box(points) + bndbox = ET.SubElement(object_elem, "bndbox") + ET.SubElement(bndbox, "xmin").text = str(int(xmin)) + ET.SubElement(bndbox, "ymin").text = str(int(ymin)) + ET.SubElement(bndbox, "xmax").text = str(int(xmax)) + ET.SubElement(bndbox, "ymax").text = str(int(ymax)) + elif shape["shape_type"] == "polygon": + xmin, ymin, xmax, ymax = self.calculate_bounding_box(points) + bndbox = ET.SubElement(object_elem, "bndbox") + ET.SubElement(bndbox, "xmin").text = str(int(xmin)) + ET.SubElement(bndbox, "ymin").text = str(int(ymin)) + ET.SubElement(bndbox, "xmax").text = str(int(xmax)) + ET.SubElement(bndbox, "ymax").text = str(int(ymax)) + polygon = ET.SubElement(object_elem, "polygon") + for i, point in enumerate(points): + x_tag = ET.SubElement(polygon, f"x{i+1}") + y_tag = ET.SubElement(polygon, f"y{i+1}") + x_tag.text = str(point[0]) + y_tag.text = str(point[1]) xml_string = ET.tostring(root, encoding="utf-8") dom = minidom.parseString(xml_string)