-
Notifications
You must be signed in to change notification settings - Fork 1
/
vis.py
79 lines (67 loc) · 2.76 KB
/
vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import cv2
import os
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import warnings
warnings.filterwarnings("ignore")
from segment_anything import sam_model_registry
from model.Network import Network
from option import args
def save_mask(pred, file):
    """Write a channel-first mask tensor to disk as an 8-bit image.

    Args:
        pred: torch tensor laid out (C, H, W); values are assumed to lie in
            [0, 1] (or be boolean) and are scaled to [0, 255] before saving.
        file: destination path, handed straight to ``cv2.imwrite``.
    """
    # (C, H, W) -> (H, W, C), detach from the graph, move to host memory.
    hwc = pred.permute(1, 2, 0).detach().cpu().numpy()
    # Scale to the 8-bit range and write the image.
    cv2.imwrite(file, (hwc * 255).astype(np.uint8))
    return
if __name__ == '__main__':
    # Paths: fine-tuned network checkpoint + one sample image to segment.
    # NOTE(review): "chekpoints" looks like a typo of "checkpoints" — confirm
    # the actual directory name on disk before renaming.
    checkpoint_file = "chekpoints/0015200.pth"
    image_path = "exp_dir/10.jpg"

    # Load the image and preprocess: 256x256, ImageNet mean/std normalization.
    image = Image.open(image_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_batch = preprocess(image).unsqueeze(0).cuda()

    # SAM image encoder: variant chosen via the --sam option (vit_b or vit_h).
    if args.sam == "vit_b":
        model_encoder = sam_model_registry["vit_b"](checkpoint="pretrained/sam_vit_b_01ec64.pth")
    else:
        model_encoder = sam_model_registry["vit_h"](checkpoint="pretrained/sam_vit_h_4b8939.pth")
    model_encoder = model_encoder.cuda()
    model_encoder.eval()

    # Decoder network + restore both models from the training checkpoint.
    model = Network(args).cuda()
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    # strict=False: the checkpoint may not cover every encoder parameter.
    model_encoder.load_state_dict(checkpoint['sam_model_state_dict'], strict=False)
    model.eval()

    # Inference under no_grad: skips autograd bookkeeping, saving GPU memory.
    with torch.no_grad():
        # Upsample the normalized batch to 1024 for the SAM encoder.
        images = torch.stack([transforms.Resize(1024)(img) for img in input_batch])
        image_embeddings, interm_embeddings = model_encoder.image_encoder(images)
        masks, iou_preds, uncertainty_p = model(images, image_embeddings, interm_embeddings, multimask_output=False)

    # Threshold the soft mask at 0.5 and save it next to the input image.
    polyp_pred = masks >= 0.5
    img_name = image_path.split('/')[-1]
    polyp_file = os.path.join('exp_dir/', img_name.split('.')[0] + '.png')
    save_mask(polyp_pred[0], polyp_file)