-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
79 lines (69 loc) · 2.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# many are copied from https://github.com/mattneary/attention/blob/master/attention/attention.py
# here it nullifies the attention over the first token (<bos>)
# which in practice we find to be a good idea
from io import BytesIO
from PIL import Image
import requests
import torch
import numpy as np
import cv2
def aggregate_llm_attention(attn):
    '''Average the attention over the most recent token across all layers and heads.

    `attn` is a sequence of per-layer attention tensors shaped
    (1, heads, rows, cols). Returns a 1-D tensor whose first and last
    entries are zeroed (null attention for <bos>, no attention for the
    just-generated token), normalized to sum to 1 per layer and then
    averaged over layers.
    '''
    per_layer = []
    zero_slot = torch.tensor([0.])
    for layer_attn in attn:
        # drop the batch dim, then average across heads
        head_avg = layer_attn.squeeze(0).mean(dim=0)
        # on the first generation step there is a row per prompt token;
        # the last row is always the newest token's attention. Skip the
        # <bos> column — it soaks up "null" attention in practice.
        newest = head_avg[-1][1:].cpu()
        padded = torch.concat((zero_slot, newest, zero_slot))
        per_layer.append(padded / padded.sum())
    return torch.stack(per_layer).mean(dim=0)
def aggregate_vit_attention(attn, select_layer=-2, all_prev_layers=True):
    '''Assuming LLaVA-style `select_layer` which is -2 by default'''
    def _head_averaged_grid(layer):
        # drop batch dim, average across heads, strip the <CLS> row/column,
        # then row-normalize so each query's attention sums to 1
        grid = layer.squeeze(0).mean(dim=0)[1:, 1:].cpu()
        return grid / grid.sum(-1, keepdim=True)

    if not all_prev_layers:
        return _head_averaged_grid(attn[select_layer])

    # include every layer up to and including the selected one
    last_idx = len(attn) + select_layer
    stacked = torch.stack([
        _head_averaged_grid(layer)
        for idx, layer in enumerate(attn)
        if idx <= last_idx
    ])
    return stacked.mean(dim=0)
def heterogenous_stack(vecs):
    '''Right-pad every vector with zeros to the longest length, then stack.'''
    target_len = max(v.shape[0] for v in vecs)
    padded = []
    for vec in vecs:
        tail = torch.zeros(target_len - vec.shape[0])
        padded.append(torch.concat((vec, tail)))
    return torch.stack(padded)
def load_image(image_path_or_url):
    '''Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_path_or_url: filesystem path, or a URL starting with
            http:// or https://.

    Returns:
        PIL.Image.Image in RGB mode.

    Raises:
        requests.HTTPError: if the URL returns a non-2xx status.
        requests.Timeout: if the download takes longer than 30 seconds.
    '''
    if image_path_or_url.startswith(('http://', 'https://')):
        # timeout keeps a dead server from hanging the caller forever;
        # raise_for_status surfaces HTTP errors directly instead of
        # letting PIL choke on an error page's body
        response = requests.get(image_path_or_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_path_or_url).convert('RGB')
    return image
def show_mask_on_image(img, mask):
    '''Blend an HSV-colormapped heatmap of `mask` onto `img`.

    Returns a tuple: (uint8 blended overlay, raw uint8 heatmap).
    '''
    base = np.float32(img) / 255
    # colormap expects a uint8 input, so rescale the mask first
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_HSV)
    blended = np.float32(heatmap) / 255 + base
    # renormalize so the brightest pixel maps back to 255
    blended = blended / np.max(blended)
    return np.uint8(255 * blended), heatmap