diff --git a/configs/_base_/uwmgi.yml b/configs/_base_/uwmgi.yml
new file mode 100644
index 0000000000..b3bb67eced
--- /dev/null
+++ b/configs/_base_/uwmgi.yml
@@ -0,0 +1,54 @@
+batch_size: 8
+iters: 160000
+
+train_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: RandomHorizontalFlip
+    - type: RandomVerticalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  train_path: data/UWMGI/train.txt
+  mode: train
+
+val_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  val_path: data/UWMGI/val.txt
+  mode: val
+
+optimizer:
+  type: SGD
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.001
+  end_lr: 0
+  power: 0.9
+
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: BCELoss
+        - type: LovaszHingeLoss
+      coef: [0.5, 0.5]  # weights of BCELoss and LovaszHingeLoss inside MixedLoss
+  coef: [1]  # weight of the single MixedLoss entry in `types`
diff --git a/configs/multilabelseg/README.md b/configs/multilabelseg/README.md
new file mode 100644
index 0000000000..9ae964adfc
--- /dev/null
+++ b/configs/multilabelseg/README.md
@@ -0,0 +1,139 @@
+English | [简体中文](README_cn.md)
+
+# Multi-label semantic segmentation based on PaddleSeg
+
+## 1. Introduction
+
+Multi-label semantic segmentation is an image segmentation task that assigns each pixel in an image to one or more categories rather than to a single category. This expresses complex relationships in an image, such as overlapping objects, occlusion, and shared boundaries, more faithfully. Multi-label semantic segmentation has many applications, such as medical image analysis, remote-sensing image interpretation, and autonomous driving.
+
+
++ *The figures above show inference results from a model trained on images from the [UWMGI](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/) dataset.*
+
+## 2. Supported models and loss functions
+
+| Model | Loss |
+|:-------------------------------------------------------------------------------------------:|:------------------------:|
+| DeepLabV3, DeepLabV3P, MobileSeg, <br>PP-LiteSeg, PP-MobileSeg, UNet, <br>UNet++, UNet+++ | BCELoss, LovaszHingeLoss |
+
++ *The models and loss functions above are those verified to work; the actual range of supported components is larger.*
+
+## 3. Sample Tutorial
+
+The following tutorial uses the **[UWMGI](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/)** multi-label semantic segmentation dataset and the **[PP-MobileSeg](../pp_mobileseg/README.md)** model as an example.
+
+### 3.1 Data Preparation
+In a single-label semantic segmentation task, the annotated grayscale image has shape **(img_h, img_w)**, and the grayscale value of each pixel is the index of its category.
+
+In a multi-label semantic segmentation task, the annotated grayscale image has shape **(img_h, num_classes x img_w)**: the binary mask of each category is concatenated horizontally, in category order (for UWMGI these are large_bowel, small_bowel, and stomach), as the sketch below illustrates.
+
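+A minimal NumPy sketch of this layout (the shapes and class count are illustrative):
+
+```python
+import numpy as np
+
+img_h, img_w, num_classes = 266, 266, 3
+
+# Pack: concatenate the per-class binary masks horizontally
+masks = [np.zeros((img_h, img_w), dtype='uint8') for _ in range(num_classes)]
+label = np.concatenate(masks, axis=1)  # (img_h, num_classes * img_w)
+
+# Unpack: recover the per-class masks
+unpacked = label.reshape(img_h, num_classes, img_w).transpose(1, 0, 2)
+assert unpacked.shape == (num_classes, img_h, img_w)
+```
+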
+Download the original zip archive of the UWMGI dataset and convert it to a format supported by PaddleSeg's [Dataset](../../paddleseg/datasets/dataset.py) API using the provided script.
+```shell
+wget -O uw-madison-gi-tract-image-segmentation.zip 'https://storage.googleapis.com/kaggle-competitions-data/kaggle-v2/27923/3495119/bundle/archive.zip?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1693533809&Signature=ThCLjIYxSXfk85lCbZ5Cz2Ta4g8AjwJv0%2FgRpqpchlZLLYxk3XRnrZqappboha0moC7FuqllpwlLfCambQMbKoUjCLylVQqF0mEsn0IaJdYwprWYY%2F4FJDT2lG0HdQfAxJxlUPonXeZyZ4pZjOrrVEMprxuiIcM2kpGk35h7ry5ajkmdQbYmNQHFAJK2iO%2F4a8%2F543zhZRWsZZVbQJHid%2BjfO6ilLWiAGnMFpx4Sh2B01TUde9hBCwpxgJv55Gs0a4Z1KNsBRly6uqwgZFYfUBAejySx4RxFB7KEuRowDYuoaRT8NhSkzT2i7qqdZjgHxkFZJpRMUlDcf1RSJVkvEA%3D%3D&response-content-disposition=attachment%3B+filename%3Duw-madison-gi-tract-image-segmentation.zip'
+python tools/data/convert_multilabel.py \
+ --dataset_type uwmgi \
+ --zip_input ./uw-madison-gi-tract-image-segmentation.zip \
+ --output ./data/UWMGI/ \
+ --train_proportion 0.8 \
+ --val_proportion 0.2
+# optional
+rm ./uw-madison-gi-tract-image-segmentation.zip
+```
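++ *`--train_proportion` and `--val_proportion` must sum to 1 (the script asserts this). Besides PaddleSeg's own dependencies, the conversion script uses `pandas`, `pycocotools`, and `tqdm`.*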
+
+The structure of the UWMGI dataset after conversion is as follows:
+```
+UWMGI
+ |
+ |--images
+ | |--train
+ | | |--*.jpg
+ | | |--...
+ | |
+ | |--val
+ | | |--*.jpg
+ | | |--...
+ |
+ |--annotations
+ | |--train
+ | | |--*.jpg
+ | | |--...
+ | |
+ | |--val
+ | | |--*.jpg
+ | | |--...
+ |
+ |--train.txt
+ |
+ |--val.txt
+```
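+
+Each line of `train.txt` and `val.txt` pairs an image with its annotation, separated by a space, for example (the file name here is illustrative):
+
+```
+images/train/case101_day20_slice_0055.jpg annotations/train/case101_day20_slice_0055.png
+```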
+
+The resulting training and validation splits can be configured as follows:
+```yaml
+train_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: RandomHorizontalFlip
+    - type: RandomVerticalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  train_path: data/UWMGI/train.txt
+  mode: train
+
+val_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  val_path: data/UWMGI/val.txt
+  mode: val
+```
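+
+In this change these blocks live in `configs/_base_/uwmgi.yml`, and the per-model configs inherit and override them through the `_base_` mechanism, for example (from `configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml`):
+
+```yaml
+_base_: '../_base_/uwmgi.yml'
+
+batch_size: 32
+iters: 160000
+```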
+
+### 3.2 Training
+```shell
+python tools/train.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --save_dir output/pp_mobileseg_tiny_uwmgi_256x256_160k \
+ --num_workers 8 \
+ --do_eval \
+ --use_vdl \
+ --save_interval 2000 \
+ --use_multilabel
+```
++ *When `--do_eval` is enabled, the `--use_multilabel` flag must be added as well, so that evaluation is adapted to multi-label mode.*
+
+### 3.3 Evaluation
+```shell
+python tools/val.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --model_path output/pp_mobileseg_tiny_uwmgi_256x256_160k/best_model/model.pdparams \
+ --use_multilabel
+```
++ *The `--use_multilabel` flag must be added when evaluating the model, so that evaluation is adapted to multi-label mode.*
+
+### 3.4 Inference
+```shell
+python tools/predict.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --model_path output/pp_mobileseg_tiny_uwmgi_256x256_160k/best_model/model.pdparams \
+ --image_path data/UWMGI/images/val/case122_day18_slice_0089.jpg \
+ --use_multilabel
+```
++ *The `--use_multilabel` flag must be added when running prediction, so that visualization is adapted to multi-label mode.*
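+
+Conceptually, multi-label prediction replaces the argmax over classes with an independent per-class sigmoid threshold (see the change to `paddleseg/core/infer.py` in this PR); a sketch with illustrative shapes:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+logit = paddle.randn([1, 3, 256, 256])  # (N, num_classes, H, W)
+
+# single-label: one class index per pixel -> (1, 1, H, W)
+pred_single = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
+# multi-label: one binary mask per class -> (1, 3, H, W)
+pred_multi = (F.sigmoid(logit) > 0.5).astype('int32')
+```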
\ No newline at end of file
diff --git a/configs/multilabelseg/README_cn.md b/configs/multilabelseg/README_cn.md
new file mode 100644
index 0000000000..2342d1ac0b
--- /dev/null
+++ b/configs/multilabelseg/README_cn.md
@@ -0,0 +1,139 @@
+[English](README.md) | 简体中文
+
+# 基于 PaddleSeg 的多标签语义分割
+
+## 1. 简介
+
+多标签语义分割是一种图像分割任务,它的目的是将图像中的每个像素分配到多个类别中,而不是只有一个类别。这样可以更好地表达图像中的复杂信息,例如不同物体的重叠、遮挡、边界等。多标签语义分割有许多应用场景,例如医学图像分析、遥感图像解译、自动驾驶等。
+
+
++ *以上效果展示图是使用基于 [UWMGI](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/) 数据集训练的模型得到的推理结果。*
+
+## 2. 已支持的模型和损失函数
+
+| Model | Loss |
+|:-------------------------------------------------------------------------------------------:|:------------------------:|
+| DeepLabV3, DeepLabV3P, MobileSeg, <br>PP-LiteSeg, PP-MobileSeg, UNet, <br>UNet++, UNet+++ | BCELoss, LovaszHingeLoss |
+
++ *以上为确认支持的模型和损失函数,实际支持范围更大。*
+
+## 3. 示例教程
+
+如下将以 **[UWMGI](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/)** 多标签语义分割数据集和 **[PP-MobileSeg](../pp_mobileseg/README.md)** 模型为例。
+
+### 3.1 数据准备
+在单标签多类别语义分割任务中,标注灰度图的形状为 **(img_h, img_w)**, 并以灰度值来表示类别的索引值。
+
+在多标签语义分割任务中,标注灰度图的形状为 **(img_h, num_classes x img_w)**,即将各个类别对应的二值标注按类别顺序拼接在水平方向上(对于 UWMGI 数据集,三个类别依次为 large_bowel、small_bowel 和 stomach)。
+
+下载UWMGI数据集的原始数据压缩包,并使用提供的脚本转换为PaddleSeg的[Dataset](../../paddleseg/datasets/dataset.py) API支持的格式。
+```shell
+wget -O uw-madison-gi-tract-image-segmentation.zip 'https://storage.googleapis.com/kaggle-competitions-data/kaggle-v2/27923/3495119/bundle/archive.zip?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1693533809&Signature=ThCLjIYxSXfk85lCbZ5Cz2Ta4g8AjwJv0%2FgRpqpchlZLLYxk3XRnrZqappboha0moC7FuqllpwlLfCambQMbKoUjCLylVQqF0mEsn0IaJdYwprWYY%2F4FJDT2lG0HdQfAxJxlUPonXeZyZ4pZjOrrVEMprxuiIcM2kpGk35h7ry5ajkmdQbYmNQHFAJK2iO%2F4a8%2F543zhZRWsZZVbQJHid%2BjfO6ilLWiAGnMFpx4Sh2B01TUde9hBCwpxgJv55Gs0a4Z1KNsBRly6uqwgZFYfUBAejySx4RxFB7KEuRowDYuoaRT8NhSkzT2i7qqdZjgHxkFZJpRMUlDcf1RSJVkvEA%3D%3D&response-content-disposition=attachment%3B+filename%3Duw-madison-gi-tract-image-segmentation.zip'
+python tools/data/convert_multilabel.py \
+ --dataset_type uwmgi \
+ --zip_input ./uw-madison-gi-tract-image-segmentation.zip \
+ --output ./data/UWMGI/ \
+ --train_proportion 0.8 \
+ --val_proportion 0.2
+# 可选
+rm ./uw-madison-gi-tract-image-segmentation.zip
+```
+
+转换完成后的UWMGI数据集结构如下:
+```
+UWMGI
+ |
+ |--images
+ | |--train
+ | | |--*.jpg
+ | | |--...
+ | |
+ | |--val
+ | | |--*.jpg
+ | | |--...
+ |
+ |--annotations
+ | |--train
+ | | |--*.jpg
+ | | |--...
+ | |
+ | |--val
+ | | |--*.jpg
+ | | |--...
+ |
+ |--train.txt
+ |
+ |--val.txt
+```
+
+划分好的训练数据集和评估数据集可按如下方式进行配置:
+```yaml
+train_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: RandomHorizontalFlip
+    - type: RandomVerticalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  train_path: data/UWMGI/train.txt
+  mode: train
+
+val_dataset:
+  type: Dataset
+  dataset_root: data/UWMGI
+  transforms:
+    - type: Resize
+      target_size: [256, 256]
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 3
+  val_path: data/UWMGI/val.txt
+  mode: val
+```
+
+### 3.2 训练模型
+```shell
+python tools/train.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --save_dir output/pp_mobileseg_tiny_uwmgi_256x256_160k \
+ --num_workers 8 \
+ --do_eval \
+ --use_vdl \
+ --save_interval 2000 \
+ --use_multilabel
+```
++ *当使用 `--do_eval` 时,必须同时添加 `--use_multilabel` 参数,以适配多标签模式下的评估。*
+
+### 3.3 评估模型
+```shell
+python tools/val.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --model_path output/pp_mobileseg_tiny_uwmgi_256x256_160k/best_model/model.pdparams \
+ --use_multilabel
+```
++ *评估模型时必须添加`--use_multilabel`参数来适配多标签模式下的评估。*
+
+### 3.4 执行预测
+```shell
+python tools/predict.py \
+ --config configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml \
+ --model_path output/pp_mobileseg_tiny_uwmgi_256x256_160k/best_model/model.pdparams \
+ --image_path data/UWMGI/images/val/case122_day18_slice_0089.jpg \
+ --use_multilabel
+```
++ *执行预测时必须添加`--use_multilabel`参数来适配多标签模式下的可视化。*
\ No newline at end of file
diff --git a/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_160k.yml b/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_160k.yml
new file mode 100644
index 0000000000..a50f3e82e1
--- /dev/null
+++ b/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_160k.yml
@@ -0,0 +1,18 @@
+_base_: '../_base_/uwmgi.yml'
+
+batch_size: 8
+iters: 160000
+
+model:
+  type: DeepLabV3
+  num_classes: 3
+  backbone:
+    type: ResNet50_vd
+    output_stride: 8
+    multi_grid: [1, 2, 4]
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
+  backbone_indices: [3]
+  aspp_ratios: [1, 12, 24, 36]
+  aspp_out_channels: 256
+  align_corners: False
+  pretrained: null
diff --git a/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_80k_withaux.yml b/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_80k_withaux.yml
new file mode 100644
index 0000000000..e2e7797120
--- /dev/null
+++ b/configs/multilabelseg/deeplabv3_resnet50_os8_uwmgi_256x256_80k_withaux.yml
@@ -0,0 +1,44 @@
+_base_: '../_base_/uwmgi.yml'
+
+batch_size: 8
+iters: 80000
+
+train_dataset:
+  transforms:
+    - type: AddMultiLabelAuxiliaryCategory  # prepend the background as an auxiliary class
+    - type: Resize
+      target_size: [256, 256]
+    - type: RandomHorizontalFlip
+    - type: RandomVerticalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 4  # 3 categories + 1 auxiliary background category
+
+val_dataset:
+  transforms:
+    - type: AddMultiLabelAuxiliaryCategory
+    - type: Resize
+      target_size: [256, 256]
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 4
+
+model:
+  type: DeepLabV3
+  num_classes: 4
+  backbone:
+    type: ResNet50_vd
+    output_stride: 8
+    multi_grid: [1, 2, 4]
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
+  backbone_indices: [3]
+  aspp_ratios: [1, 12, 24, 36]
+  aspp_out_channels: 256
+  align_corners: False
+  pretrained: null
diff --git a/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml b/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml
new file mode 100644
index 0000000000..b1cf55d50c
--- /dev/null
+++ b/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_160k.yml
@@ -0,0 +1,34 @@
+_base_: '../_base_/uwmgi.yml'
+
+batch_size: 32
+iters: 160000
+
+optimizer:
+  _inherited_: False  # replace, rather than merge with, the base config's SGD optimizer
+  type: AdamW
+  weight_decay: 0.01
+  custom_cfg:
+    - name: pos_embed
+      weight_decay_mult: 0.0
+    - name: head
+      lr_multi: 10.0
+    - name: bn
+      weight_decay_mult: 0.0
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.0006
+  end_lr: 0
+  power: 1.0
+  warmup_iters: 1500
+  warmup_start_lr: 1.0e-6
+
+model:
+  type: PPMobileSeg
+  num_classes: 3
+  backbone:
+    type: MobileSeg_Tiny
+    inj_type: AAM
+    out_feat_chs: [32, 64, 128]
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_tiny_pretrain/model.pdparams
+  upsample: intepolate  # change this to vim when exporting the model to use VIM
diff --git a/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_80k_withaux.yml b/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_80k_withaux.yml
new file mode 100644
index 0000000000..f8fb4bb3dc
--- /dev/null
+++ b/configs/multilabelseg/pp_mobileseg_tiny_uwmgi_256x256_80k_withaux.yml
@@ -0,0 +1,60 @@
+_base_: '../_base_/uwmgi.yml'
+
+batch_size: 32
+iters: 80000
+
+train_dataset:
+  transforms:
+    - type: AddMultiLabelAuxiliaryCategory  # prepend the background as an auxiliary class
+    - type: Resize
+      target_size: [256, 256]
+    - type: RandomHorizontalFlip
+    - type: RandomVerticalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 4
+
+val_dataset:
+  transforms:
+    - type: AddMultiLabelAuxiliaryCategory
+    - type: Resize
+      target_size: [256, 256]
+    - type: Normalize
+      mean: [0.0, 0.0, 0.0]
+      std: [1.0, 1.0, 1.0]
+  num_classes: 4
+
+optimizer:
+  _inherited_: False  # replace, rather than merge with, the base config's SGD optimizer
+  type: AdamW
+  weight_decay: 0.01
+  custom_cfg:
+    - name: pos_embed
+      weight_decay_mult: 0.0
+    - name: head
+      lr_multi: 10.0
+    - name: bn
+      weight_decay_mult: 0.0
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.0006
+  end_lr: 0
+  power: 1.0
+  warmup_iters: 1500
+  warmup_start_lr: 1.0e-6
+
+model:
+  type: PPMobileSeg
+  num_classes: 4  # 3 categories + 1 auxiliary background category, matching the datasets above
+  backbone:
+    type: MobileSeg_Tiny
+    inj_type: AAM
+    out_feat_chs: [32, 64, 128]
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_tiny_pretrain/model.pdparams
+  upsample: intepolate  # change this to vim when exporting the model to use VIM
diff --git a/paddleseg/core/infer.py b/paddleseg/core/infer.py
index d5df03e86e..66a529164b 100644
--- a/paddleseg/core/infer.py
+++ b/paddleseg/core/infer.py
@@ -136,7 +136,8 @@ def inference(model,
               trans_info=None,
               is_slide=False,
               stride=None,
-              crop_size=None):
+              crop_size=None,
+              use_multilabel=False):
     """
     Inference for image.
@@ -147,6 +148,7 @@ def inference(model,
         is_slide (bool): Whether to infer by sliding window. Default: False.
         crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
         stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         Tensor: If ori_shape is not None, a prediction with shape (1, 1, h, w) is returned.
@@ -167,7 +169,10 @@ def inference(model,
logit = logit.transpose((0, 3, 1, 2))
if trans_info is not None:
logit = reverse_transform(logit, trans_info, mode='bilinear')
- pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
+ if not use_multilabel:
+ pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
+ else:
+ pred = (F.sigmoid(logit) > 0.5).astype('int32')
return pred, logit
else:
return logit
@@ -181,7 +186,8 @@ def aug_inference(model,
                   flip_vertical=False,
                   is_slide=False,
                   stride=None,
-                  crop_size=None):
+                  crop_size=None,
+                  use_multilabel=False):
     """
     Infer with augmentation.
@@ -195,6 +201,7 @@ def aug_inference(model,
         is_slide (bool): Whether to infer by sliding wimdow. Default: False.
         crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
         stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         Tensor: Prediction of image with shape (1, 1, h, w) is returned.
@@ -229,6 +236,9 @@ def aug_inference(model,
     # comparable to single-scale logits
     final_logit /= num_augs
     final_logit = reverse_transform(final_logit, trans_info, mode='bilinear')
-    pred = paddle.argmax(final_logit, axis=1, keepdim=True, dtype='int32')
-    return pred, final_logit
+    if not use_multilabel:
+        pred = paddle.argmax(final_logit, axis=1, keepdim=True, dtype='int32')
+    else:
+        pred = (F.sigmoid(final_logit) > 0.5).astype('int32')
+    return pred, final_logit
\ No newline at end of file
diff --git a/paddleseg/core/predict.py b/paddleseg/core/predict.py
index 73d2f866de..016a93878c 100644
--- a/paddleseg/core/predict.py
+++ b/paddleseg/core/predict.py
@@ -58,7 +58,8 @@ def predict(model,
             is_slide=False,
             stride=None,
             crop_size=None,
-            custom_color=None):
+            custom_color=None,
+            use_multilabel=False):
     """
     predict and visualize the image_list.
@@ -79,6 +80,7 @@ def predict(model,
         crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
             It should be provided when `is_slide` is True.
         custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     """
     utils.utils.load_entire_model(model, model_path)
@@ -110,7 +112,8 @@ def predict(model,
                     flip_vertical=flip_vertical,
                     is_slide=is_slide,
                     stride=stride,
-                    crop_size=crop_size)
+                    crop_size=crop_size,
+                    use_multilabel=use_multilabel)
             else:
                 pred, _ = infer.inference(
                     model,
@@ -118,7 +121,8 @@ def predict(model,
                     trans_info=data['trans_info'],
                     is_slide=is_slide,
                     stride=stride,
-                    crop_size=crop_size)
+                    crop_size=crop_size,
+                    use_multilabel=use_multilabel)

             pred = paddle.squeeze(pred)
             pred = pred.numpy().astype('uint8')
@@ -132,13 +136,14 @@ def predict(model,
             # save added image
             added_image = utils.visualize.visualize(
-                im_path, pred, color_map, weight=0.6)
+                im_path, pred, color_map, weight=0.6, use_multilabel=use_multilabel)
             added_image_path = os.path.join(added_saved_dir, im_file)
             mkdir(added_image_path)
             cv2.imwrite(added_image_path, added_image)

             # save pseudo color prediction
-            pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map)
+            pred_mask = utils.visualize.get_pseudo_color_map(
+                pred, color_map, use_multilabel=use_multilabel)
             pred_saved_path = os.path.join(
                 pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
             mkdir(pred_saved_path)
diff --git a/paddleseg/core/val.py b/paddleseg/core/val.py
index 80a820b6bc..437c9acf4f 100644
--- a/paddleseg/core/val.py
+++ b/paddleseg/core/val.py
@@ -38,7 +38,8 @@ def evaluate(model,
              amp_level='O1',
              num_workers=0,
              print_detail=True,
-             auc_roc=False):
+             auc_roc=False,
+             use_multilabel=False):
     """
     Launch evalution.
@@ -59,6 +60,7 @@ def evaluate(model,
         num_workers (int, optional): Num workers for data loader. Default: 0.
         print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.
         auc_roc(bool, optional): whether add auc_roc metric
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         float: The mIoU of validation datasets.
@@ -120,7 +122,8 @@ def evaluate(model,
                             flip_vertical=flip_vertical,
                             is_slide=is_slide,
                             stride=stride,
-                            crop_size=crop_size)
+                            crop_size=crop_size,
+                            use_multilabel=use_multilabel)
                 else:
                     pred, logits = infer.aug_inference(
                         model,
@@ -131,7 +134,8 @@ def evaluate(model,
                         flip_vertical=flip_vertical,
                         is_slide=is_slide,
                         stride=stride,
-                        crop_size=crop_size)
+                        crop_size=crop_size,
+                        use_multilabel=use_multilabel)
             else:
                 if precision == 'fp16':
                     with paddle.amp.auto_cast(
@@ -148,7 +152,8 @@ def evaluate(model,
                             trans_info=data['trans_info'],
                             is_slide=is_slide,
                             stride=stride,
-                            crop_size=crop_size)
+                            crop_size=crop_size,
+                            use_multilabel=use_multilabel)
                 else:
                     pred, logits = infer.inference(
                         model,
@@ -156,13 +161,15 @@ def evaluate(model,
                         trans_info=data['trans_info'],
                         is_slide=is_slide,
                         stride=stride,
-                        crop_size=crop_size)
+                        crop_size=crop_size,
+                        use_multilabel=use_multilabel)

             intersect_area, pred_area, label_area = metrics.calculate_area(
                 pred,
                 label,
                 eval_dataset.num_classes,
-                ignore_index=eval_dataset.ignore_index)
+                ignore_index=eval_dataset.ignore_index,
+                use_multilabel=use_multilabel)

             # Gather from all ranks
             if nranks > 1:
diff --git a/paddleseg/datasets/dataset.py b/paddleseg/datasets/dataset.py
index d518f5b4f8..f2a0c8593f 100644
--- a/paddleseg/datasets/dataset.py
+++ b/paddleseg/datasets/dataset.py
@@ -155,7 +155,8 @@ def __getitem__(self, idx):
         data['gt_fields'] = []
         if self.mode == 'val':
             data = self.transforms(data)
-            data['label'] = data['label'][np.newaxis, :, :]
+            # Multi-label annotations are already (C, H, W); only expand single-label 2-D masks
+            if data['label'].ndim == 2:
+                data['label'] = data['label'][np.newaxis, :, :]
         else:
             data['gt_fields'].append('label')
diff --git a/paddleseg/transforms/transforms.py b/paddleseg/transforms/transforms.py
index eb298a4a6c..05d11da4e7 100644
--- a/paddleseg/transforms/transforms.py
+++ b/paddleseg/transforms/transforms.py
@@ -82,6 +82,11 @@ def __call__(self, data):
         if 'label' in data.keys() and isinstance(data['label'], str):
             data['label'] = np.asarray(Image.open(data['label']))
+            img_h, img_w = data['img'].shape[:2]
+            if data['label'].shape[0] != img_h:
+                # (num_classes * img_h, img_w) -> (img_h, img_w, num_classes)
+                data['label'] = data['label'].reshape([-1, img_h, img_w]).transpose([1, 2, 0])
+            elif data['label'].shape[1] != img_w:
+                # (img_h, num_classes * img_w) -> (img_h, img_w, num_classes)
+                data['label'] = data['label'].reshape([img_h, -1, img_w]).transpose([0, 2, 1])

         # the `trans_info` will save the process of image shape, and will be used in evaluation and prediction.
         if 'trans_info' not in data.keys():
@@ -93,6 +98,8 @@ def __call__(self, data):
         if data['img'].ndim == 2:
             data['img'] = data['img'][..., np.newaxis]
         data['img'] = np.transpose(data['img'], (2, 0, 1))
+        if 'label' in data and data['label'].ndim == 3:
+            # Move the multi-label mask to CHW order, matching the image
+            data['label'] = np.transpose(data['label'], (2, 0, 1))

         return data
@@ -1224,3 +1231,17 @@ def __call__(self, data):
             data['instances'] = instances

         return data
+
+
+@manager.TRANSFORMS.add_component
+class AddMultiLabelAuxiliaryCategory:
+    """
+    Prepend an auxiliary category whose mask is the complement of the union
+    of all annotated categories, i.e. a background mask that is 1 wherever
+    no other category is labeled.
+    """
+
+    def __call__(self, data):
+        if 'label' in data:
+            aux_label = (data['label'].sum(axis=-1, keepdims=True) == 0).astype('uint8')
+            data['label'] = np.concatenate([aux_label, data['label']], axis=-1)
+
+        return data
diff --git a/paddleseg/utils/metrics.py b/paddleseg/utils/metrics.py
index 5327a464f0..fd7b0c3ba4 100644
--- a/paddleseg/utils/metrics.py
+++ b/paddleseg/utils/metrics.py
@@ -18,7 +18,7 @@
import sklearn.metrics as skmetrics
-def calculate_area(pred, label, num_classes, ignore_index=255):
+def calculate_area(pred, label, num_classes, ignore_index=255, use_multilabel=False):
"""
Calculate intersect, prediction and label area
@@ -27,36 +27,42 @@ def calculate_area(pred, label, num_classes, ignore_index=255):
         label (Tensor): The ground truth of image.
         num_classes (int): The unique number of target classes.
         ignore_index (int): Specifies a target value that is ignored. Default: 255.
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         Tensor: The intersection area of prediction and the ground on all class.
         Tensor: The prediction area on all class.
         Tensor: The ground truth area on all class
     """
-    if len(pred.shape) == 4:
-        pred = paddle.squeeze(pred, axis=1)
-    if len(label.shape) == 4:
-        label = paddle.squeeze(label, axis=1)
-    if not pred.shape == label.shape:
-        raise ValueError('Shape of `pred` and `label should be equal, '
-                         'but there are {} and {}.'.format(pred.shape,
-                                                           label.shape))
-    pred_area = []
-    label_area = []
-    intersect_area = []
-    mask = label != ignore_index
-
-    for i in range(num_classes):
-        pred_i = paddle.logical_and(pred == i, mask)
-        label_i = label == i
-        intersect_i = paddle.logical_and(pred_i, label_i)
-        pred_area.append(paddle.sum(paddle.cast(pred_i, "int64")))
-        label_area.append(paddle.sum(paddle.cast(label_i, "int64")))
-        intersect_area.append(paddle.sum(paddle.cast(intersect_i, "int64")))
-
-    pred_area = paddle.stack(pred_area)
-    label_area = paddle.stack(label_area)
-    intersect_area = paddle.stack(intersect_area)
+    if not use_multilabel:
+        if len(pred.shape) == 4:
+            pred = paddle.squeeze(pred, axis=1)
+        if len(label.shape) == 4:
+            label = paddle.squeeze(label, axis=1)
+        if not pred.shape == label.shape:
+            raise ValueError('Shape of `pred` and `label should be equal, '
+                             'but there are {} and {}.'.format(pred.shape,
+                                                               label.shape))
+        pred_area = []
+        label_area = []
+        intersect_area = []
+        mask = label != ignore_index
+
+        for i in range(num_classes):
+            pred_i = paddle.logical_and(pred == i, mask)
+            label_i = label == i
+            intersect_i = paddle.logical_and(pred_i, label_i)
+            pred_area.append(paddle.sum(paddle.cast(pred_i, "int64")))
+            label_area.append(paddle.sum(paddle.cast(label_i, "int64")))
+            intersect_area.append(paddle.sum(paddle.cast(intersect_i, "int64")))
+
+        pred_area = paddle.stack(pred_area)
+        label_area = paddle.stack(label_area)
+        intersect_area = paddle.stack(intersect_area)
+    else:
+        # In multi-label mode, `pred` and `label` are binary tensors of shape
+        # (N, C, H, W), so per-class areas are plain sums over N, H and W.
+        pred_area = pred.sum([0, 2, 3]).astype('int64')
+        label_area = label.sum([0, 2, 3]).astype('int64')
+        intersect_area = (pred * label).sum([0, 2, 3]).astype('int64')

     return intersect_area, pred_area, label_area
diff --git a/paddleseg/utils/visualize.py b/paddleseg/utils/visualize.py
index 27211c4113..d6e5842ff1 100644
--- a/paddleseg/utils/visualize.py
+++ b/paddleseg/utils/visualize.py
@@ -19,7 +19,7 @@
from PIL import Image as PILImage
-def visualize(image, result, color_map, save_dir=None, weight=0.6):
+def visualize(image, result, color_map, save_dir=None, weight=0.6, use_multilabel=False):
"""
Convert predict result to color image, and save added image.
@@ -29,6 +29,7 @@ def visualize(image, result, color_map, save_dir=None, weight=0.6):
         color_map (list): The color used to save the prediction results.
         save_dir (str): The directory for saving visual image. Default: None.
         weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         vis_result (np.ndarray): If `save_dir` is None, return the visualized result.
@@ -36,14 +37,29 @@ def visualize(image, result, color_map, save_dir=None, weight=0.6):
     color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
     color_map = np.array(color_map).astype("uint8")

-    # Use OpenCV LUT for color mapping
-    c1 = cv2.LUT(result, color_map[:, 0])
-    c2 = cv2.LUT(result, color_map[:, 1])
-    c3 = cv2.LUT(result, color_map[:, 2])
-    pseudo_img = np.dstack((c3, c2, c1))

     im = cv2.imread(image)
-    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
+    if not use_multilabel:
+        # Use OpenCV LUT for color mapping
+        c1 = cv2.LUT(result, color_map[:, 0])
+        c2 = cv2.LUT(result, color_map[:, 1])
+        c3 = cv2.LUT(result, color_map[:, 2])
+        pseudo_img = np.dstack((c3, c2, c1))
+
+        vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
+    else:
+        # Blend each class mask into the image and outline it with its color
+        vis_result = im.copy()
+        for i in range(result.shape[0]):
+            mask = result[i]
+            c1 = np.where(mask, color_map[i, 0], vis_result[..., 0])
+            c2 = np.where(mask, color_map[i, 1], vis_result[..., 1])
+            c3 = np.where(mask, color_map[i, 2], vis_result[..., 2])
+            pseudo_img = np.dstack((c3, c2, c1)).astype('uint8')
+
+            contour, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            vis_result = cv2.addWeighted(vis_result, weight, pseudo_img, 1 - weight, 0)
+            contour_color = (int(color_map[i, 0]), int(color_map[i, 1]), int(color_map[i, 2]))
+            vis_result = cv2.drawContours(vis_result, contour, -1, contour_color, 1)

     if save_dir is not None:
         if not os.path.exists(save_dir):
@@ -55,7 +71,7 @@ def visualize(image, result, color_map, save_dir=None, weight=0.6):
         return vis_result
-def get_pseudo_color_map(pred, color_map=None):
+def get_pseudo_color_map(pred, color_map=None, use_multilabel=False):
"""
Get the pseudo color image.
@@ -63,10 +79,16 @@ def get_pseudo_color_map(pred, color_map=None):
         pred (numpy.ndarray): the origin predicted image.
         color_map (list, optional): the palette color map. Default: None,
             use paddleseg's default color map.
+        use_multilabel (bool, optional): Whether to enable multilabel mode. Default: False.

     Returns:
         (numpy.ndarray): the pseduo image.
     """
+    if use_multilabel:
+        # Treat pixels not covered by any class as background (index 0)
+        bg_pred = (pred.sum(axis=0, keepdims=True) == 0).astype('int32')
+        pred = np.concatenate([bg_pred, pred], axis=0)
+        # Collapse the one-hot (C, H, W) masks into a single index map;
+        # this assumes the per-class masks do not overlap
+        gray_idx = np.arange(pred.shape[0]).astype(np.uint8)
+        pred = (pred * gray_idx[:, None, None]).sum(axis=0)
     pred_mask = PILImage.fromarray(pred.astype(np.uint8), mode='P')
     if color_map is None:
         color_map = get_color_map_list(256)
diff --git a/tools/data/convert_multilabel.py b/tools/data/convert_multilabel.py
new file mode 100644
index 0000000000..0fe6372030
--- /dev/null
+++ b/tools/data/convert_multilabel.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+File: convert_multilabel.py
+This file is used to convert `uwmgi` or `coco` type dataset to support multi-label dataset format.
+Examples of usage are as follows:
+1. convert UWMGI dataset
+python convert_multilabel.py --dataset_type uwmgi --zip_input ${uwmgi_origin_zip_file} --output ${save_dir} --train_proportion 0.8 --val_proportion 0.2
+2. convert COCO type dataset
+2.1 not yet split training and validation dataset
+python convert_multilabel.py --dataset_type coco --img_input ${img_dir} --ann_input ${ann_dir} --output ${save_dir} --train_proportion 0.8 --val_proportion 0.2
+2.2 training and validation dataset split
+python convert_multilabel.py --dataset_type coco --img_input ${train_img_dir} --ann_input ${train_ann_dir} --output ${save_dir} --train_proportion 1.0 --val_proportion 0.0
+python convert_multilabel.py --dataset_type coco --img_input ${val_img_dir} --ann_input ${val_ann_dir} --output ${save_dir} --train_proportion 0.0 --val_proportion 1.0
+"""
+
+import argparse
+import os
+import random
+import zipfile
+
+import cv2
+import numpy as np
+import pandas as pd
+from PIL import Image
+from pycocotools.coco import COCO
+from tqdm import tqdm
+
+
+def uwmgi_get_image(fp):
+    # Normalize the scan to [0, 255] by its own maximum and replicate it to 3 channels
+    image = np.array(Image.open(fp))
+    image = image.astype('float32')
+    image = image / np.max(image) * 255
+    image = np.tile(image[..., None], [1, 1, 3])
+    image = image.astype('uint8')
+    return image
+
+
+def uwmgi_get_image_id(image_filepath):
+    # e.g. 'train/case101/case101_day20/scans/slice_0001_266_266_1.50_1.50.png'
+    # -> 'case101_day20_slice_0001', matching the `id` column of train.csv
+    image_dirs = image_filepath.replace('/', '\\').split('\\')
+    image_dirs = [image_dirs[2]] + image_dirs[4].split('_')[:2]
+    image_id = '_'.join(image_dirs)
+    return image_id
+
+
+def uwmgi_rle_decode(mask_rle, image_shape):
+    # Kaggle-style RLE: space-separated pairs of (1-indexed start, run length)
+    s = mask_rle.split()
+    starts, lengths = [np.asarray(x, dtype=int)
+                       for x in (s[0:][::2], s[1:][::2])]
+    starts -= 1
+    ends = starts + lengths
+    img = np.zeros(image_shape[0] * image_shape[1], dtype='uint8')
+    for low, high in zip(starts, ends):
+        img[low:high] = 1
+    return img.reshape(image_shape)
+
+
+def uwmgi_to_multilabel_format(args):
+    with zipfile.ZipFile(args.zip_input, 'r') as zip_fp:
+        total_df = pd.read_csv(zip_fp.open('train.csv', 'r'))
+
+        total_image_namelist = []
+        for name in zip_fp.namelist():
+            if os.path.splitext(name)[1] == '.png':
+                total_image_namelist.append(name)
+        train_image_namelist = random.sample(
+            total_image_namelist,
+            int(len(total_image_namelist) * args.train_proportion))
+        val_image_namelist = np.setdiff1d(
+            total_image_namelist, train_image_namelist)
+
+        pbar = tqdm(total=len(total_image_namelist))
+        for image_namelist, split in zip(
+                [train_image_namelist, val_image_namelist], ['train', 'val']):
+            txt_lines = []
+            for image_name in image_namelist:
+                with zip_fp.open(image_name, 'r') as fp:
+                    image = uwmgi_get_image(fp)
+                image_id = uwmgi_get_image_id(image_name)
+                anns = total_df[total_df['id'] == image_id]
+                height, width = image.shape[:2]
+                # Concatenate the three binary masks horizontally:
+                # [large_bowel | small_bowel | stomach]
+                mask = np.zeros([height, width * 3], dtype='uint8')
+                for _, ann in anns.iterrows():
+                    if not pd.isna(ann['segmentation']):
+                        if ann['class'] == 'large_bowel':
+                            mask[:, 0:width] = uwmgi_rle_decode(
+                                ann['segmentation'], (height, width))
+                        elif ann['class'] == 'small_bowel':
+                            mask[:, width:width * 2] = uwmgi_rle_decode(
+                                ann['segmentation'], (height, width))
+                        else:  # ann['class'] == 'stomach'
+                            mask[:, width * 2:] = uwmgi_rle_decode(
+                                ann['segmentation'], (height, width))
+                cv2.imwrite(os.path.join(
+                    args.output, 'images', split, image_id + '.jpg'), image)
+                cv2.imwrite(os.path.join(
+                    args.output, 'annotations', split, image_id + '.png'), mask)
+                txt_lines.append(
+                    os.path.join('images', split, image_id + '.jpg')
+                    + ' ' + os.path.join('annotations', split, image_id + '.png'))
+                pbar.update()
+
+            with open(os.path.join(args.output, split + '.txt'), 'w') as fp:
+                fp.write('\n'.join(txt_lines))
+
+
+def coco_to_multilabel_format(args):
+    coco = COCO(args.ann_input)
+    cat_id_map = {
+        old_cat_id: new_cat_id
+        for new_cat_id, old_cat_id in enumerate(coco.getCatIds())
+    }
+    num_classes = len(list(cat_id_map.keys()))
+
+    assert 'annotations' in coco.dataset, \
+        'Annotation file: {} does not contain ground truth!'.format(args.ann_input)
+
+    total_img_id_list = sorted(list(coco.imgToAnns.keys()))
+    train_img_id_list = random.sample(
+        total_img_id_list, int(len(total_img_id_list) * args.train_proportion))
+    val_img_id_list = np.setdiff1d(total_img_id_list, train_img_id_list)
+
+    pbar = tqdm(total=len(total_img_id_list))
+    for img_id_list, split in zip(
+            [train_img_id_list, val_img_id_list], ['train', 'val']):
+        txt_lines = []
+        for img_id in img_id_list:
+            img_info = coco.loadImgs([img_id])[0]
+            img_filename = img_info['file_name']
+            img_w = img_info['width']
+            img_h = img_info['height']
+
+            img_filepath = os.path.join(args.img_input, img_filename)
+            if not os.path.exists(img_filepath):
+                print('Illegal image file: {}, '
+                      'and it will be ignored'.format(img_filepath))
+                continue
+
+            if img_w < 0 or img_h < 0:
+                print('Illegal width: {} or height: {} in annotation, '
+                      'and im_id: {} will be ignored'.format(img_w, img_h, img_id))
+                continue
+
+            ann_ids = coco.getAnnIds(imgIds=[img_id])
+            anns = coco.loadAnns(ann_ids)
+
+            # One horizontal slot of width img_w per category
+            mask = np.zeros([img_h, num_classes * img_w], dtype='uint8')
+            for ann in anns:
+                cat_id = cat_id_map[ann['category_id']]
+                one_cls_mask = coco.annToMask(ann)
+                mask[:, cat_id * img_w: (cat_id + 1) * img_w] = np.where(
+                    one_cls_mask, one_cls_mask,
+                    mask[:, cat_id * img_w: (cat_id + 1) * img_w])
+
+            image = cv2.imread(img_filepath, cv2.IMREAD_COLOR)
+            cv2.imwrite(os.path.join(
+                args.output, 'images', split,
+                os.path.splitext(img_filename)[0] + '.jpg'), image)
+            cv2.imwrite(os.path.join(
+                args.output, 'annotations', split,
+                os.path.splitext(img_filename)[0] + '.png'), mask)
+            txt_lines.append(os.path.join(
+                'images', split, os.path.splitext(img_filename)[0] + '.jpg')
+                + ' ' + os.path.join(
+                    'annotations', split, os.path.splitext(img_filename)[0] + '.png'))
+            pbar.update()
+
+        with open(os.path.join(args.output, split + '.txt'), 'w') as fp:
+            fp.write('\n'.join(txt_lines))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--dataset_type',
+        help='the type of dataset, can be `uwmgi` or `coco`',
+        type=str)
+    parser.add_argument(
+        "--zip_input",
+        help="the path of the original dataset zip file",
+        type=str)
+    parser.add_argument(
+        "--img_input",
+        help="the directory of the original dataset images",
+        type=str)
+    parser.add_argument(
+        "--ann_input",
+        help="the path of the original dataset annotation file",
+        type=str)
+    parser.add_argument(
+        "--output",
+        help="the directory to save the converted dataset",
+        type=str)
+    parser.add_argument(
+        '--train_proportion',
+        help='the proportion of the training dataset',
+        type=float,
+        default=0.8)
+    parser.add_argument(
+        '--val_proportion',
+        help='the proportion of the validation dataset',
+        type=float,
+        default=0.2)
+    args = parser.parse_args()
+
+    assert args.dataset_type in ['uwmgi', 'coco'], \
+        "Only the `uwmgi` and `coco` dataset types are supported!"
+
+    assert 0 <= args.train_proportion <= 1
+    assert 0 <= args.val_proportion <= 1
+    assert args.train_proportion + args.val_proportion == 1
+
+    if not os.path.exists(args.output):
+        os.makedirs(args.output, exist_ok=True)
+
+    os.makedirs(os.path.join(args.output, 'images/train'), exist_ok=True)
+    os.makedirs(os.path.join(args.output, 'annotations/train'), exist_ok=True)
+    os.makedirs(os.path.join(args.output, 'images/val'), exist_ok=True)
+    os.makedirs(os.path.join(args.output, 'annotations/val'), exist_ok=True)
+
+    if args.dataset_type == 'uwmgi':
+        assert os.path.exists(args.zip_input), \
+            f"The zip file ({args.zip_input}) of " \
+            f"the original UWMGI dataset does not exist!"
+        assert zipfile.is_zipfile(args.zip_input)
+
+        uwmgi_to_multilabel_format(args)
+
+    else:  # args.dataset_type == 'coco'
+        assert os.path.exists(args.img_input), \
+            f"The directory ({args.img_input}) of " \
+            f"the original images does not exist!"
+        assert os.path.exists(args.ann_input), \
+            f"The file ({args.ann_input}) of " \
+            f"the original annotations does not exist!"
+
+        coco_to_multilabel_format(args)
+
+    print("Dataset converted successfully; saved to: {}".format(args.output))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/predict.py b/tools/predict.py
index 09302d2052..9e1016cc5c 100644
--- a/tools/predict.py
+++ b/tools/predict.py
@@ -98,6 +98,13 @@ def parse_args():
         help='Save images with a custom color map. Default: None, use paddleseg\'s default color map.',
         type=int)

+    # Set multi-label mode
+    parser.add_argument(
+        '--use_multilabel',
+        action='store_true',
+        default=False,
+        help='Whether to enable multilabel mode. Default: False.')
+
     return parser.parse_args()
@@ -118,6 +125,8 @@ def merge_test_config(cfg, args):
         test_config['stride'] = args.stride
     if args.custom_color:
         test_config['custom_color'] = args.custom_color
+    if args.use_multilabel:
+        test_config['use_multilabel'] = args.use_multilabel

     return test_config
diff --git a/tools/train.py b/tools/train.py
index 09d864499a..b9ce6cf7af 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -124,6 +124,12 @@ def parse_args():
     )
     parser.add_argument(
         '--opts', help='Update the key-value pairs of all options.', nargs='+')
+    # Set multi-label mode
+    parser.add_argument(
+        '--use_multilabel',
+        action='store_true',
+        default=False,
+        help='Whether to enable multilabel mode. Default: False.')

     return parser.parse_args()
@@ -145,6 +151,12 @@ def main(args):
     utils.set_device(args.device)
     utils.set_cv2_num_threads(args.num_workers)

+    if args.use_multilabel:
+        # Propagate the flag into test_config so that evaluation during
+        # training also runs in multi-label mode
+        if 'test_config' not in cfg.dic:
+            cfg.dic['test_config'] = {'use_multilabel': True}
+        else:
+            cfg.dic['test_config']['use_multilabel'] = True
+
     # TODO refactor
     # Only support for the DeepLabv3+ model
     if args.data_format == 'NHWC':
diff --git a/tools/val.py b/tools/val.py
index 454737608b..2ce837972f 100644
--- a/tools/val.py
+++ b/tools/val.py
@@ -97,6 +97,12 @@ def parse_args():
         help='Update the key-value pairs of all options.',
         default=None,
         nargs='+')
+    # Set multi-label mode
+    parser.add_argument(
+        '--use_multilabel',
+        action='store_true',
+        default=False,
+        help='Whether to enable multilabel mode. Default: False.')

     return parser.parse_args()
@@ -112,6 +118,8 @@ def merge_test_config(cfg, args):
     test_config['is_slide'] = args.is_slide
     test_config['crop_size'] = args.crop_size
     test_config['stride'] = args.stride
+    if args.use_multilabel:
+        test_config['use_multilabel'] = args.use_multilabel

     return test_config