No evaluation codes #5

Open
duan-song opened this issue Jun 6, 2024 · 3 comments
duan-song commented Jun 6, 2024

Hi,

Thanks for open-sourcing the GEM code. However, I cannot reproduce the mIoU scores on Pascal VOC, Pascal Context, ADE20K, and OpenImages30K reported in the CVPR 2024 manuscript. Would the authors be willing to release the evaluation code?
[screenshot attached]

@letitiabanana

Hi authors,

Thanks for the great work!
I tried writing my own evaluation code and got mIoU = 15 on the VOC dataset, which deviates significantly from the reported number. I believe there must be some discrepancy between my reimplementation and your code. Could you please release the evaluation code?

Thanks a lot!

@WalBouss (Owner)

Hi,
Thanks for your interest in our work and your feedback!
I don’t have time to clean up all the evaluation pipelines, but here is the one for PascalVOC. I will try to push it to the repo when I have time:

Evaluation Pipeline for PascalVOC.
Don’t forget to change the path to the PascalVOC dataset (root_path_voc). (You can download the dataset at http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar):
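After extracting the archive, root_path_voc should point to the directory that contains VOCdevkit/ (the dataset class appends VOCdevkit/VOC2012/ itself).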

from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassJaccardIndex
from einops import rearrange
import gem



class ZeroShotSegmentation(torch.nn.Module):
    def __init__(self, model, tokenizer, model_name, patch_size=16, device='cpu'):
        super(ZeroShotSegmentation, self).__init__()

        self.model_name = model_name
        self.device = device

        self.gem_model = model
        self.gem_model.to(device)
        self.gem_model.eval()
        self.patch_size = patch_size
        self.tokenizer = tokenizer


    def _get_text_embedding(self, classes: list):
        prompts = [f'a photo of a {cls}.' for cls in classes]

        tokenized_prompts = self.tokenizer(prompts).to(self.device)

        text_embedding = self.gem_model.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def inference(self, image, text_embedding, mask_shape):
        B, _, H, W = image.shape
        # forward images
        feat_gem, feat_ori = self.gem_model.model.visual(image)
        feat_gem = F.normalize(feat_gem, dim=-1)

        # Patch/Text similarity
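        # feat_gem[:, 1:] drops the CLS token; the 100.0 factor matches the standard CLIP logit scale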
        logits_gem = 100.0 * feat_gem[:, 1:] @ text_embedding.transpose(1, 2)
        logits_gem = rearrange(logits_gem, 'b (h w) c -> b c h w', h=H // self.patch_size, w=W // self.patch_size)
        # Interpolate
        logits_gem = F.interpolate(logits_gem, size=mask_shape, mode='bilinear')

        # Segmentation prediction: argmax over foreground classes; +1 shifts indices past the background class (0)
        pred_gem = logits_gem.argmax(1) + 1

        return pred_gem, logits_gem

    @torch.no_grad()
    def eval_dataset(self, dataloader, classes, device):
        text_embedding = self._get_text_embedding(classes=classes[1:])  # remove background class

        threshold = 0.85
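        # pixels whose max softmax score falls below this threshold are reassigned to the background class (0) in the loop below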
        metric_iou = MulticlassJaccardIndex(num_classes=len(classes), ignore_index=-1).to('cpu')


        for i, (image, mask) in enumerate(tqdm(dataloader)):
            image = image.to(device)  # the mask stays on CPU to match the metric

            # pred_gem: [batch, H, W] | pred_logits_gem: [batch, num_class, H, W]
            pred_gem, pred_logits_gem = self.inference(image, text_embedding, mask.shape[-2:])

            # keep the highest class probability for each pixel
            logits_soft_max_gem = pred_logits_gem.softmax(dim=1).max(dim=1)[0]  # [batch, H, W]
            # clone argmaxed prediction
            pred_th_gem = pred_gem.clone()

            # apply threshold
            pred_th_gem[logits_soft_max_gem < threshold] = 0  # replace values under the threshold with the background class

            # Compute the IoU
            metric_iou(pred_th_gem.cpu(), mask)
            if i % 20 == 0:
                print(metric_iou.compute().item() * 100)  # running mIoU

        metric_th_gem = 100 * metric_iou.compute().item()
        print(f'mIoU: {metric_th_gem}')

        return metric_th_gem



def main(model_name, device, pretrained, patch_size=16, root_path_voc='', batch_size=1):
    # # Select Dataset
    if batch_size > 1:
        resize_mask = True
    else:
        resize_mask = False
    dataset = PascalVOC(root=root_path_voc, split='val',
                        transform=SegmentationTransforms((448, 448), resize_mask=resize_mask),
                        aug=False, only_image=False, only_mask=False, ignore_index=-1)

    test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=8)

    # # Model
    model = gem.create_gem_model(model_name=model_name, pretrained=pretrained)
    tokenizer = gem.get_tokenizer(model_name=model_name)

    # # Evaluator
    zero_shot_evaluator = ZeroShotSegmentation(model=model, device=device, patch_size=patch_size,
                                               model_name=model_name, tokenizer=tokenizer)


    miou_list_cs = zero_shot_evaluator.eval_dataset(dataloader=test_loader,
                                                    classes=list(dataset.CLASSES),
                                                    device=device)
    return miou_list_cs



if __name__ == '__main__':
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    patch_size = 16
    model_name = 'ViT-B-16-quickgelu'
    pretrained = 'metaclip_400m'
    root_path_voc = '/path/to/PascalVOC/'

    print('########################################')
    print(f'model: {model_name} | pretrained: {pretrained} ')
    print('########################################')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(model_name=model_name, pretrained=pretrained, device=device, patch_size=patch_size, root_path_voc=root_path_voc)
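Note that the import at the bottom of the script assumes the dataset code below is saved as segmentation_datasets/pascal_voc.py.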

Here is the dataset implementation:

from os.path import join
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import transforms


OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


class PascalVOC(Dataset):
    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')

    PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
                           [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
                           [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
                           [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
                           [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]], dtype=torch.uint8)
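    # each PALETTE row is the RGB color that encodes the class at the same index in CLASSES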

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 only_image=False,
                 aug=True,
                 nclass=None,
                 only_mask=False,
                 split_file=None,
                 ignore_index=-1,
                 return_path=False):
        super(PascalVOC, self).__init__()
        self.nclass = nclass if nclass is not None else self.PALETTE.shape[0]
        self.only_image = only_image
        self.only_mask = only_mask
        self.split = split
        self.return_path = return_path
        self.ignore_index = ignore_index
        assert self.split in ['train', 'trainval', 'val'], f'{self.split} must be in ["train", "trainval", "val"]'
        self.split = 'trainaug' if aug and (self.split == 'train') else self.split
        self.root = join(root, 'VOCdevkit/VOC2012/') if split_file is None else root
        self.transform = transform



        self.anno_type = 'SegmentationClassAug' if aug else 'SegmentationClass'
        txt_file = join(self.root, split_file) if split_file is not None \
            else join(self.root, 'ImageSets', 'Segmentation', self.split + '.txt')

        self.samples = []
        with open(txt_file) as f:
            samples_tmp = f.readlines()
        samples_tmp = list(map(lambda elem: elem.strip(), samples_tmp))
        self.samples.extend(samples_tmp)

        self.image_files = []
        self.label_files = []
        for sample in self.samples:
            if split_file is not None:
                img = f'{str(sample)}.jpg'
                label = f'{str(sample)}.png'
            else:
                img = f'JPEGImages/{str(sample)}.jpg'
                label = f'{self.anno_type}/{str(sample)}.png'
            self.image_files.append(join(self.root, img))
            self.label_files.append(join(self.root, label))

    def __len__(self):
        return len(self.image_files)


    def __getitem__(self, idx):

        image_path = self.image_files[idx]
        label_path = self.label_files[idx]

        img, msk = Image.open(image_path).convert("RGB"), Image.open(label_path).convert("RGB")
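        # VOC masks are palette PNGs; converting to RGB lets pixels be matched against PALETTE below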

        images, rgb_target = self.transform(img, msk)

        h, w = rgb_target.shape[1:]
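        # decode the RGB mask into per-pixel class indices; unmatched pixels keep ignore_index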
        one_hot_seg_mask = self.ignore_index * torch.ones((h, w), dtype=torch.long)
        for color_idx in range(self.nclass):
            color_match = (rgb_target == self.PALETTE[color_idx].unsqueeze(-1).unsqueeze(-1))
            valid_idx = (color_match.sum(0) == 3)  # pixel matches all three RGB channels
            one_hot_seg_mask[valid_idx] = color_idx

        if self.return_path:
            path_to_img_msk = {}
            path_to_img_msk["img_path"] = image_path
            path_to_img_msk["label_path"] = label_path
            return images, one_hot_seg_mask, path_to_img_msk

        return images, one_hot_seg_mask


class ToTensorMask(nn.Module):
    def __init__(self):
        super(ToTensorMask, self).__init__()

    def forward(self, mask):
        return torch.as_tensor(np.array(mask), dtype=torch.int64).permute(2, 0, 1)


class SegmentationTransforms(object):
    def __init__(self, size, img_transforms=None, resize_mask=False):
        self.img_transforms = img_transforms if img_transforms is not None else transforms.Compose([
            transforms.Resize(size=size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
        ])
        self.mask_transforms = transforms.Compose([
            transforms.Resize(size=size, interpolation=transforms.InterpolationMode.NEAREST) if resize_mask else nn.Identity(),  # nearest keeps label colors exact
            ToTensorMask(),
        ])

    def __call__(self, image, mask):
        return self.img_transforms(image), self.mask_transforms(mask)



if __name__ == '__main__':
    root = '/path/to/PascalVOC/'
    dataset = PascalVOC(root=root, split='train', transform=SegmentationTransforms((448, 448), resize_mask=False),
                        aug=False, only_image=False, only_mask=False)

    test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)

    for img, mask in test_loader:
        print(img.shape)
        print(mask.shape)

You will also need to install the torchmetrics library via pip install torchmetrics
Feel free to ask if you have any questions!

@letitiabanana

Hi,

I am now able to reproduce your result for VOC. Thanks a lot for your reply!!

I am also interested in the different behaviors of models pre-trained with single vs. multiple objectives, e.g., CLIP and BLIP. Would you mind sharing how your method could be implemented with BLIP as well?

Thanks again!
