-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No evaluation codes #5
Comments
Hi authors, thanks for the great work! Thanks a lot!
Hi, here is the evaluation pipeline for PascalVOC:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassJaccardIndex
from einops import rearrange
import gem
class ZeroShotSegmentation(torch.nn.Module):
    """Zero-shot semantic-segmentation evaluator built on a GEM model.

    Patch features from the GEM visual tower are compared with text embeddings
    of ``'a photo of a {cls}.'`` prompts. Each pixel takes the most similar
    foreground class. Pixels whose top softmax probability falls below a
    threshold are assigned to background (label 0).
    """

    def __init__(self, model, tokenizer, model_name, patch_size=16, device='cpu'):
        """
        Args:
            model: GEM model exposing ``model.model.encode_text`` and
                ``model.model.visual``. The latter must return a
                ``(gem_features, original_features)`` pair.
            tokenizer: callable mapping a list of prompt strings to a token
                tensor accepted by ``encode_text``.
            model_name: backbone name; stored for reference only.
            patch_size: ViT patch size, used to fold patch tokens back into an
                (H / patch_size, W / patch_size) grid.
            device: device the model and the tokenized prompts are moved to.
        """
        super().__init__()
        self.model_name = model_name
        self.device = device
        self.gem_model = model
        self.gem_model.to(device)
        self.gem_model.eval()
        self.patch_size = patch_size
        self.tokenizer = tokenizer

    def _get_text_embedding(self, classes: list):
        """Return L2-normalized prompt embeddings with a leading batch dim of 1.

        Assumes ``encode_text`` returns (num_prompts, dim), so the result is
        (1, num_prompts, dim). TODO confirm for other GEM backbones.
        """
        prompts = [f'a photo of a {cls}.' for cls in classes]
        tokenized_prompts = self.tokenizer(prompts).to(self.device)
        text_embedding = self.gem_model.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        # The leading dim lets the embedding broadcast against (B, N, dim) patch features.
        return text_embedding.unsqueeze(0)

    def inference(self, image, text_embedding, mask_shape):
        """Predict a per-pixel class map for a batch of images.

        Args:
            image: normalized image batch (B, 3, H, W). H and W must be
                divisible by ``patch_size``.
            text_embedding: output of ``_get_text_embedding`` for the C
                foreground classes.
            mask_shape: (H_out, W_out) size the logits are upsampled to.

        Returns:
            ``(pred, logits)``:
            - ``pred`` is (B, H_out, W_out) with labels in ``1..C``; 0 is left
              free for background.
            - ``logits`` is (B, C, H_out, W_out).
        """
        _, _, height, width = image.shape
        # The visual tower returns (GEM features, original features); only GEM is used.
        feat_gem, _ = self.gem_model.model.visual(image)
        feat_gem = F.normalize(feat_gem, dim=-1)
        # Patch/text cosine similarity scaled by 100. Token 0 (CLS) is dropped,
        # so only the patch tokens remain.
        logits_gem = 100.0 * feat_gem[:, 1:] @ text_embedding.transpose(1, 2)
        logits_gem = rearrange(logits_gem, 'b (h w) c -> b c h w',
                               h=height // self.patch_size, w=width // self.patch_size)
        # align_corners=False is the PyTorch default; stated explicitly for clarity.
        logits_gem = F.interpolate(logits_gem, size=mask_shape, mode='bilinear', align_corners=False)
        # Shift foreground indices by one so that label 0 stays reserved for background.
        pred_gem = logits_gem.argmax(1) + 1
        return pred_gem, logits_gem

    @torch.no_grad()
    def eval_dataset(self, dataloader, classes, device=None, threshold=0.85, ignore_index=-1, log_every=20):
        """Compute the thresholded mIoU (in percent) over ``dataloader``.

        Args:
            dataloader: yields ``(image, mask)`` batches. ``mask`` holds class
                indices, with ``ignore_index`` on unlabeled pixels.
            classes: class names with the background class first.
            device: device images are moved to; defaults to ``self.device``.
            threshold: pixels whose highest softmax probability is below this
                value are assigned to background (class 0).
            ignore_index: mask value excluded from the IoU computation.
            log_every: print the running mIoU every ``log_every`` batches.
                ``0`` or ``None`` disables it.

        Returns:
            mIoU over all ``len(classes)`` classes, as a percentage.
        """
        if device is None:
            device = self.device
        # Background gets no prompt; it is produced only by the confidence threshold.
        text_embedding = self._get_text_embedding(classes=classes[1:])
        metric_iou = MulticlassJaccardIndex(num_classes=len(classes), ignore_index=ignore_index).to('cpu')
        for i, (image, mask) in enumerate(tqdm(dataloader)):
            image = image.to(device)
            # pred_gem: (B, H, W) | pred_logits_gem: (B, num_foreground_classes, H, W)
            pred_gem, pred_logits_gem = self.inference(image, text_embedding, mask.shape[-2:])
            # Per-pixel confidence is the highest class probability: (B, H, W).
            confidence = pred_logits_gem.softmax(dim=1).max(dim=1).values
            pred_th_gem = pred_gem.clone()
            pred_th_gem[confidence < threshold] = 0
            # The metric lives on the CPU, next to the (CPU) ground-truth mask.
            metric_iou(pred_th_gem.cpu(), mask)
            if log_every and i % log_every == 0:
                print(metric_iou.compute().item() * 100)
        metric_th_gem = 100 * metric_iou.compute().item()
        print(f'mIoU: {metric_th_gem}')
        return metric_th_gem
def main(model_name, device, pretrained, patch_size=16, root_path_voc='', batch_size=1, num_workers=8):
    """Evaluate a GEM model zero-shot on the PascalVOC ``val`` split.

    Args:
        model_name: GEM/open_clip model name, e.g. ``'ViT-B-16-quickgelu'``.
        device: device used for the model and the images.
        pretrained: name of the pretrained weights passed to ``gem.create_gem_model``.
        patch_size: ViT patch size of ``model_name``.
        root_path_voc: dataset root that contains ``VOCdevkit/VOC2012/``.
        batch_size: evaluation batch size.
        num_workers: DataLoader worker processes.

    Returns:
        The thresholded mIoU (percent) reported by ``ZeroShotSegmentation.eval_dataset``.
    """
    # Imported here, not only under ``__main__``, so ``main`` also works when
    # this module is imported and called from elsewhere.
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    # Batching needs masks of a common shape, so masks are resized only when
    # batch_size > 1. With batch_size == 1 they keep their native resolution.
    resize_mask = batch_size > 1
    dataset = PascalVOC(root=root_path_voc, split='val',
                        transform=SegmentationTransforms((448, 448), resize_mask=resize_mask),
                        aug=False, only_image=False, only_mask=False, ignore_index=-1)
    test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=num_workers)

    model = gem.create_gem_model(model_name=model_name, pretrained=pretrained)
    tokenizer = gem.get_tokenizer(model_name=model_name)

    zero_shot_evaluator = ZeroShotSegmentation(model=model, device=device, patch_size=patch_size,
                                               model_name=model_name, tokenizer=tokenizer)
    miou = zero_shot_evaluator.eval_dataset(dataloader=test_loader,
                                            classes=list(dataset.CLASSES),
                                            device=device)
    return miou
if __name__ == '__main__':
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    patch_size = 16
    model_name = 'ViT-B-16-quickgelu'
    pretrained = 'metaclip_400m'
    # Dataset root; PascalVOC expects it to contain VOCdevkit/VOC2012/.
    root_path_voc = '/path/to/PascalVOC/'
    print('########################################')
    print(f'model: {model_name} | pretrained: {pretrained} ')
    print('########################################')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(model_name=model_name, pretrained=pretrained, device=device, patch_size=patch_size,
         root_path_voc=root_path_voc)

# Here is the dataset implementation (imported above as segmentation_datasets.pascal_voc):
from os.path import join
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import transforms
# Per-channel RGB mean and standard deviation used by ``transforms.Normalize``
# in the default image pipeline (the OpenAI CLIP preprocessing statistics).
OPENAI_DATASET_MEAN = (
    0.48145466,
    0.4578275,
    0.40821073,
)
OPENAI_DATASET_STD = (
    0.26862954,
    0.26130258,
    0.27577711,
)
class PascalVOC(Dataset):
    """PASCAL VOC 2012 semantic-segmentation dataset.

    Each sample is an RGB image plus its palette-colored annotation. The
    annotation is decoded into a (H, W) int64 tensor of class indices by
    matching each pixel against the first ``nclass`` entries of ``PALETTE``.
    Pixels whose color matches no entry get ``ignore_index``.
    """
    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')
    # PALETTE[i] is the RGB color of class CLASSES[i] in the annotation images.
    PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
                            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
                            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
                            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
                            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]], dtype=torch.uint8)

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 only_image=False,
                 aug=True,
                 nclass=None,
                 only_mask=False,
                 split_file=None,
                 ignore_index=-1,
                 return_path=False):
        """
        Args:
            root: dataset root. Without ``split_file`` it must contain
                ``VOCdevkit/VOC2012/``. With ``split_file`` it is used as-is.
            split: one of ``'train'``, ``'trainval'``, ``'val'``. ``'train'``
                becomes ``'trainaug'`` when ``aug`` is true.
            transform: callable ``(PIL image, PIL mask) -> (image, mask)``.
                The returned mask must be an RGB tensor of shape (3, H, W).
                Required by ``__getitem__``.
            only_image, only_mask: stored on the instance; not used by this class.
            aug: read ``SegmentationClassAug`` annotations (and the
                ``trainaug`` list for ``'train'``) instead of ``SegmentationClass``.
            nclass: number of palette colors to decode; defaults to all of ``PALETTE``.
            split_file: optional list file relative to ``root``. Its entries are
                paths relative to ``root`` without an extension.
            ignore_index: label given to pixels that match no palette color.
            return_path: if true, ``__getitem__`` also returns the file paths.
        """
        super().__init__()
        self.nclass = nclass if nclass is not None else self.PALETTE.shape[0]
        self.only_image = only_image
        self.only_mask = only_mask
        self.split = split
        self.return_path = return_path
        self.ignore_index = ignore_index
        assert self.split in ['train', 'trainval', 'val'], f'{self.split} must be in ["train", "trainval", "val"]'
        self.split = 'trainaug' if aug and (self.split == 'train') else self.split
        self.root = join(root, 'VOCdevkit/VOC2012/') if split_file is None else root
        self.transform = transform
        self.anno_type = 'SegmentationClassAug' if aug else 'SegmentationClass'
        txt_file = join(self.root, split_file) if split_file is not None \
            else join(self.root, 'ImageSets', 'Segmentation', self.split + '.txt')
        with open(txt_file) as f:
            # Skip blank lines so they do not turn into bogus '.jpg' / '.png' paths.
            self.samples = [line.strip() for line in f if line.strip()]
        self.image_files = []
        self.label_files = []
        for sample in self.samples:
            if split_file is not None:
                img = f'{sample}.jpg'
                label = f'{sample}.png'
            else:
                img = f'JPEGImages/{sample}.jpg'
                label = f'{self.anno_type}/{sample}.png'
            self.image_files.append(join(self.root, img))
            self.label_files.append(join(self.root, label))

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)``, or ``(image, mask, paths)`` with ``return_path``.

        ``mask`` is a (H, W) int64 tensor of class indices. ``paths`` is a dict
        with the keys ``"img_path"`` and ``"label_path"``.
        """
        image_path = self.image_files[idx]
        label_path = self.label_files[idx]
        with Image.open(image_path) as image_file:
            img = image_file.convert("RGB")
        with Image.open(label_path) as label_file:
            msk = label_file.convert("RGB")
        images, rgb_target = self.transform(img, msk)
        h, w = rgb_target.shape[1:]
        seg_mask = torch.full((h, w), self.ignore_index, dtype=torch.long)
        for color_idx in range(self.nclass):
            color = self.PALETTE[color_idx].view(3, 1, 1)
            # A pixel belongs to this class only if all three channels match the color.
            matches = (rgb_target == color).all(dim=0)
            seg_mask[matches] = color_idx
        if self.return_path:
            return images, seg_mask, {"img_path": image_path, "label_path": label_path}
        return images, seg_mask
class ToTensorMask(nn.Module):
    """Convert a channel-last mask (e.g. an RGB PIL image) to a channel-first int64 tensor.

    Values are kept as raw integers, with no rescaling to [0, 1], so palette
    colors can be compared exactly. Input (H, W, C) becomes output (C, H, W).
    """

    def forward(self, mask):
        # np.array copies the pixels into an (H, W, C) array.
        pixels = np.array(mask)
        as_int = torch.as_tensor(pixels, dtype=torch.int64)
        # Move channels first: (H, W, C) -> (C, H, W).
        return as_int.permute(2, 0, 1)
class SegmentationTransforms:
    """Paired image/mask preprocessing for segmentation.

    Image pipeline: bicubic resize, then tensor conversion, then normalization
    with the OpenAI CLIP statistics.

    Mask pipeline: optional resize, then ``ToTensorMask``. The mask is an RGB
    palette image, so it is resized with nearest-neighbor interpolation.
    Bilinear blending would invent colors at object boundaries. For example,
    the midpoint of (0, 0, 0) and (128, 0, 0) is (64, 0, 0), which is another
    class's palette color, and would produce wrong labels.
    """

    def __init__(self, size, img_transforms=None, resize_mask=False):
        """
        Args:
            size: target size given to ``transforms.Resize`` for the image,
                and for the mask when ``resize_mask`` is true.
            img_transforms: optional replacement for the default image pipeline.
            resize_mask: resize the mask to ``size`` as well. Otherwise the
                mask keeps its native resolution.
        """
        if img_transforms is None:
            img_transforms = transforms.Compose([
                transforms.Resize(size=size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
            ])
        self.img_transforms = img_transforms
        mask_resize = (transforms.Resize(size=size, interpolation=transforms.InterpolationMode.NEAREST)
                       if resize_mask else nn.Identity())
        self.mask_transforms = transforms.Compose([
            mask_resize,
            ToTensorMask(),
        ])

    def __call__(self, image, mask):
        """Return ``(transformed_image, transformed_mask)``."""
        return self.img_transforms(image), self.mask_transforms(mask)
if __name__ == '__main__':
    # Smoke test: walk the training split and print the tensor shapes of each sample.
    root = '/path/to/PascalVOC/'
    demo_transform = SegmentationTransforms((448, 448), resize_mask=False)
    dataset = PascalVOC(root=root, split='train', transform=demo_transform,
                        aug=False, only_image=False, only_mask=False)
    test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)
    for batch_img, batch_mask in test_loader:
        print(batch_img.shape)
        print(batch_mask.shape)

# You will also need to install the torchmetrics library (e.g. `pip install torchmetrics`).
Hi, I am now able to reproduce your results for VOC. Thanks a lot for your reply! I am also interested in how models pre-trained with a single objective versus multiple objectives (e.g. CLIP vs. BLIP) behave differently. Would you mind sharing how your method can be implemented with BLIP as well? Thanks again!
Hi,
Thanks for open-sourcing the GEM code. However, I cannot reproduce the mIoU scores on Pascal VOC, Pascal Context, ADE20K, and OpenImages30K reported in the CVPR 2024 paper. Would the authors be willing to release the evaluation code?
The text was updated successfully, but these errors were encountered: