-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add multi-label semantic segmentation support #3479
[Feature] Add multi-label semantic segmentation support #3479
Conversation
Thanks for your contribution! |
pred = logit | ||
pred = (1 - 2 * label) * pred | ||
pred_neg = pred - label * 1e12 | ||
pred_pos = pred - (1 - label) * 1e12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里如果使用where和现在的代码哪个更快?可以测试一下常用尺寸的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
经测试,loss部分batch小的时候算术方法较快,batch大的时候where占优,还有一个是标签的稀疏性也会影响计算时间
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试的图片尺寸是多少,稀疏性的影响大不大
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
尺寸 512 X 512, 当label中有效值1得占比远小于 1/num_classes时,算术占优
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非极端情况下两者时间差别不大
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那还是用where吧,看起来可读性好点
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
label (Tensor): Label tensor, the data type is int64. Shape is (N, C), where each | ||
value is 0 or 1, and if shape is more than 2D, this is | ||
(N, C, D1, D2,..., Dk), k >= 1. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
描述似乎不太准确
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更正
paddleseg/transforms/transforms.py
Outdated
data['gt_fields'] = [] | ||
|
||
if self.mode.lower() == 'train': | ||
assert 'instances' in data, ValueError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
写出详细报错信息吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已重新调整COCOInstance api的结构
paddleseg/transforms/transforms.py
Outdated
for idx, one_class_label in enumerate(label): | ||
data[f'label_{idx}'] = one_class_label | ||
if self.mode.lower() == 'train': | ||
data['gt_fields'].append(f'label_{idx}') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要分开呢?似乎大部分针对label的transform都可以通过修改functional中类似与[a:b, c:d, :] 为 [a:b, c:d, ...] 解决。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已重新调整COCOInstance api的结构
paddleseg/transforms/transforms.py
Outdated
data['gt_fields'].append(f'label_{idx}') | ||
|
||
try: | ||
del data['instances'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用 data.pop('instances', None)吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已重新调整COCOInstance api的结构
paddleseg/utils/metrics.py
Outdated
for i in range(num_classes): | ||
pred_i = pred[:, i] | ||
label_i = label[:, i] | ||
intersect_i = paddle.logical_and(pred_i, label_i.astype(paddle.int32)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看起来似乎不需要使用循环了,上面使用循环是因为类别都拍平在一个h w中了,多标签的情况应该是0-1标签,直接计算就可以吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,已经调整为直接取和
paddleseg/transforms/transforms.py
Outdated
|
||
label = np.concatenate(label, axis=0) | ||
|
||
label[label == self.ignore_index] = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为啥要将ignore_index变为0呢,变为0就表示负类了,而ignore应该表示不参与loss计算
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经更正
value is 0 or 1, and if shape is more than 2D, this is | ||
(N, C, D1, D2,..., Dk), k >= 1. | ||
""" | ||
label = label.astype(paddle.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需不需要将ignore_index的位置求出来,使其不参与正类与负类的计算?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经更正
paddleseg/transforms/transforms.py
Outdated
""" | ||
def __init__(self, mode="train", use_multilabel=False, ignore_index=255): | ||
self.mode = mode | ||
self.use_multilabel = use_multilabel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我认为把use_multilabel加到dataset合理一点,并在config或者builder中传参,同时需要在config_checker中添加对loss的检查(使用multilabel不支持的损失函数)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
重新优化了COCOInstance的逻辑,使用allow_overlap(bool),来控制是否开启多标签模式
@@ -0,0 +1,51 @@ | |||
batch_size: 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
或许可以继承singlelabel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经改为继承
60b6a69
to
8fe2fb9
Compare
@Asthestarsfalll 我已经按照建议进行修改,并更新了pr对应分支的内容 |
paddleseg/datasets/coco_instance.py
Outdated
self.num_classes = self.NUM_CLASSES | ||
self.ignore_index = self.IGNORE_INDEX | ||
self.allow_overlap = allow_overlap | ||
self.add_background = add_background |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个参数的作用是啥,和多标签有关系吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
allow_overlop参数得意思是是否允许标注重叠,若允许则为多标签模型,不允许则为单标签模式
add_background参数的意思是是否需要将背景添加为一类
weight=None, | ||
ignore_index=255, | ||
top_k_percent_pixels=1.0, | ||
avg_non_ignore=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有些参数似乎没有用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修正
logit_pos = paddle.where(paddle.logical_and(label, mask), | ||
logit, paddle.to_tensor(float("-inf"))) | ||
logit_neg = paddle.where(paddle.logical_or(label, ~mask), | ||
paddle.to_tensor(float("-inf")), logit) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段的可读性比较差,在上一次代码的基础上对logit_pos和logit_neg用where就行了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert len(label.shape) == len(logit.shape)
logit = logit.transpose([0, 2, 3, 1])
logexp_one = paddle.zeros_like(logit[..., :1])
logit_pos = paddle.where((label == 1), logit, paddle.to_tensor(float('inf')))
logit_pos = paddle.concat([logexp_one, logit_pos], axis=-1)
loss_pos = paddle.logsumexp(-logit_pos, axis=-1)
logit_neg = paddle.where((label == 0), logit, paddle.to_tensor(float('-inf')))
logit_neg = paddle.concat([logexp_one, logit_neg], axis=-1)
loss_neg = paddle.logsumexp(logit_neg, axis=-1)
loss = loss_pos + loss_neg
mask = (label != self.ignore_index).astype('float32')
loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return loss
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改,提高可读性
|
||
|
||
@manager.LOSSES.add_component | ||
class MultiLabelAsymmetricLoss(nn.Layer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个loss的参考资料有吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加
paddleseg/utils/metrics.py
Outdated
label_area = label.sum(0).sum(-1).sum(-1).astype("int64") | ||
intersect = paddle.logical_and( | ||
pred.astype("bool"), label.astype("bool")).astype("int64") | ||
intersect_area = intersect.sum(0).sum(-1).sum(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些地方用一个sum就行了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为这里label的形状为(bs, num_classes, h, w)需要保留num_classes维度
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sum是支持多个axis的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
还有建议在config_checker中添加对多标签分割检查loss的逻辑 |
super(MultiLabelCategoricalCrossEntropyLoss, self).__init__() | ||
self.ignore_index = ignore_index | ||
self.EPS = 1e-8 | ||
self.data_format = data_format |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下面data_format的逻辑怎么又去掉了
|
||
logexp_one = paddle.zeros_like(logit[..., :1]) | ||
|
||
logit_pos = paddle.where((label == 1), logit, paddle.to_tensor(float('inf'))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle.where应该是支持非布尔矩阵的吧,直接用label就行了吧?下同
assert len(label.shape) == len(logit.shape) | ||
logit = logit.transpose([0, 2, 3, 1]) | ||
|
||
logexp_one = paddle.zeros_like(logit[..., :1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
名字有点迷惑了,用zero吧
|
||
loss = loss_pos + loss_neg | ||
mask = (label != self.ignore_index).astype('float32') | ||
loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
除以mask mean是为了放缩loss,以消除对忽略这部分值对loss大小的影响,但是前面的步骤并没有对这部分的值进行忽略,应该需要加一个loss=loss*mask
参考一下PR提交规范格式化一下代码 |
9729310
to
25ba1f7
Compare
paddleseg/cvlibs/builder.py
Outdated
@@ -119,7 +119,7 @@ def model(self) -> paddle.nn.Layer: | |||
'No model specified in the configuration file.' | |||
|
|||
if self.config.train_dataset_cfg[ | |||
'type'] not in ['Dataset', 'SegDataset']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
本文件中的两处或许不需要修改?为什么不做numclass的一致检查呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在修改为直接使用Dataset API
paddleseg/datasets/coco.py
Outdated
import numpy as np | ||
import pycocotools.coco as cocoAPI | ||
import pycocotools.mask as maskUtils | ||
from paddle.io import Dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle的引用需要放在下一个代码块,和第三方库区分。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在修改为直接使用Dataset API
paddleseg/core/predict.py
Outdated
added_image = utils.visualize.multi_label_visualize( | ||
im_path, pred, color_map, weight=0.6) | ||
added_image_path = os.path.join(added_saved_dir, im_file) | ||
mkdir(added_image_path) | ||
cv2.imwrite(added_image_path, added_image) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这一段重复度较大,或者使用函数包装下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已将多标签模式下的可视化功能与单标签模式融合
train_dataset: | ||
type: COCODataset | ||
image_root: data/UWMGI/images/ | ||
json_file: data/UWMGI/annotations/train.json | ||
add_background: True | ||
use_multilabel: True | ||
transforms: | ||
- type: ResizeStepScaling | ||
min_scale_factor: 0.5 | ||
max_scale_factor: 2.0 | ||
scale_step_size: 0.25 | ||
- type: RandomPaddingCrop | ||
crop_size: [512, 512] | ||
- type: RandomHorizontalFlip | ||
- type: Normalize | ||
mean: [0.0, 0.0, 0.0] | ||
std: [1.0, 1.0, 1.0] | ||
mode: train | ||
|
||
val_dataset: | ||
type: COCODataset | ||
image_root: data/UWMGI/images/ | ||
json_file: data/UWMGI/annotations/val.json | ||
add_background: True | ||
use_multilabel: True | ||
transforms: | ||
- type: Resize | ||
target_size: [2048, 512] | ||
keep_ratio: True | ||
size_divisor: 32 | ||
- type: Normalize | ||
mean: [0.0, 0.0, 0.0] | ||
std: [1.0, 1.0, 1.0] | ||
mode: val |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
数据集部分可以单独出抽离出来放在_base_,正如configs/base/cityscapes.yml
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已将uwmgi.yml置入_base_
* Install PaddlePaddle and relative environments based on the [installation guide](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/install/pip/linux-pip_en.html). | ||
* Install PaddleSeg based on the [reference](../../../docs/install.md). | ||
* Download the UWMGI dataset and link to PaddleSeg/data. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加怎样准备多标签数据,有必要的话,可以增加转换脚本。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加转换脚本,并在说明文档中列出使用说明
9e0498a
to
0a5adbb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
整体PR完整度很高了,仅留下了几个小comments。
configs/multilabelseg/README.md
Outdated
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case15_day0_slice_0065.jpg"> | ||
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case122_day18_slice_0092.jpg"> | ||
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case130_day20_slice_0072.jpg"> | ||
</p> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新
configs/multilabelseg/README_cn.md
Outdated
<p align="center"> | ||
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case15_day0_slice_0065.jpg"> | ||
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case122_day18_slice_0092.jpg"> | ||
<img src="https://github.com/MINGtoMING/cache_ppseg_multilabelseg_readme_imgs/tree/main/assets/case130_day20_slice_0072.jpg"> | ||
</p> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新
tools/data/convert_uwmgi.py
Outdated
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个脚本是否可复用呢,对所有特定格式的多标签数据都进行转换?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新脚本,使其支持UWMGI
和主流的COCO类型标注转换为ppseg dataset api支持的格式
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
目前paddleseg只支持单标签的语义分割,即图像上的某一空间位置的像素点只能对应一个类别(类比单标签分类),而多标签语义分割在空间维度上来看不同实例间的mask可能会重叠,这就需要图像上的某一空间位置的像素点能同时对应多个类别(类比多标签分类)。但目前关于多标签语义分割的工作较其他视觉任务少得多,故在数据集和模型更加偏向于自定义。为了便于自定义数据集或模型使用我添加了如下模块:
增大图片尺寸、辅助样本可以加速收敛
在UWMGI数据上,利用ppmobileseg训练的精度,尺寸为[512,512],其中第一类为增加的背景类,平均mIOU为82.18: