probes.py
import torch as t
class LRProbe(t.nn.Module):
    """Logistic-regression probe: a single linear layer followed by a sigmoid."""

    def __init__(self, d_in):
        super().__init__()
        self.net = t.nn.Sequential(
            t.nn.Linear(d_in, 1, bias=False),
            t.nn.Sigmoid()
        )

    def forward(self, x, iid=None):
        return self.net(x).squeeze(-1)

    def pred(self, x, iid=None):
        return self(x).round()

    @staticmethod
    def from_data(acts, labels, lr=0.001, weight_decay=0.1, epochs=1000, device='cpu'):
        """Fit a probe to activations `acts` and binary `labels` with AdamW."""
        acts, labels = acts.to(device), labels.to(device)
        probe = LRProbe(acts.shape[-1]).to(device)

        opt = t.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
        for _ in range(epochs):
            opt.zero_grad()
            loss = t.nn.BCELoss()(probe(acts), labels)
            loss.backward()
            opt.step()

        return probe

    def __str__(self):
        return "LRProbe"

    @property
    def direction(self):
        return self.net[0].weight.data[0]
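

# Illustrative usage sketch (not part of the original file): fitting a
# logistic-regression probe on hidden-state activations. `acts` and `labels`
# are hypothetical stand-ins for activations and truth labels gathered
# elsewhere in the pipeline.
def _example_lr_probe():
    acts = t.randn(256, 1024)                 # fake (n, d_model) activations
    labels = t.randint(0, 2, (256,)).float()  # fake binary truth labels
    probe = LRProbe.from_data(acts, labels, epochs=100)
    acc = (probe.pred(acts) == labels).float().mean().item()
    return probe.direction, acc               # learned direction and train accuracy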
class MMProbe(t.nn.Module):
    """Mass-mean probe: the direction is the difference of class means; with
    `iid=True`, inputs are additionally multiplied by the inverse covariance."""

    def __init__(self, direction, covariance=None, inv=None, atol=1e-3):
        super().__init__()
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if inv is None:
            self.inv = t.nn.Parameter(
                t.linalg.pinv(covariance, hermitian=True, atol=atol),
                requires_grad=False
            )
        else:
            self.inv = t.nn.Parameter(inv, requires_grad=False)

    def forward(self, x, iid=False):
        if iid:
            return t.nn.Sigmoid()(x @ self.inv @ self.direction)
        else:
            return t.nn.Sigmoid()(x @ self.direction)

    def pred(self, x, iid=False):
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(acts, labels, atol=1e-3, device='cpu'):
        acts, labels = acts.to(device), labels.to(device)
        pos_acts, neg_acts = acts[labels == 1], acts[labels == 0]
        pos_mean, neg_mean = pos_acts.mean(0), neg_acts.mean(0)
        direction = pos_mean - neg_mean

        centered_data = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], 0)
        covariance = centered_data.t() @ centered_data / acts.shape[0]

        probe = MMProbe(direction, covariance=covariance).to(device)
        return probe

    def __str__(self):
        return "MMProbe"
def ccs_loss(probe, acts, neg_acts):
    """CCS objective: a statement and its negation should get probabilities
    that sum to one (consistency) and that are far from 0.5 (confidence)."""
    p_pos = probe(acts)
    p_neg = probe(neg_acts)
    consistency_losses = (p_pos - (1 - p_neg)) ** 2
    confidence_losses = t.min(t.stack((p_pos, p_neg), dim=-1), dim=-1).values ** 2
    return t.mean(consistency_losses + confidence_losses)
class CCSProbe(t.nn.Module):
    """Contrast-Consistent Search probe, trained unsupervised with `ccs_loss`."""

    def __init__(self, d_in):
        super().__init__()
        self.net = t.nn.Sequential(
            t.nn.Linear(d_in, 1, bias=False),
            t.nn.Sigmoid()
        )

    def forward(self, x, iid=None):
        return self.net(x).squeeze(-1)

    def pred(self, acts, iid=None):
        return self(acts).round()

    @staticmethod
    def from_data(acts, neg_acts, labels=None, lr=0.001, weight_decay=0.1, epochs=1000, device='cpu'):
        acts, neg_acts = acts.to(device), neg_acts.to(device)
        probe = CCSProbe(acts.shape[-1]).to(device)

        opt = t.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
        for _ in range(epochs):
            opt.zero_grad()
            loss = ccs_loss(probe, acts, neg_acts)
            loss.backward()
            opt.step()

        if labels is not None:  # flip direction if it ended up anti-aligned with the labels
            acc = (probe.pred(acts) == labels).float().mean()
            if acc < 0.5:
                probe.net[0].weight.data *= -1

        return probe

    def __str__(self):
        return "CCSProbe"

    @property
    def direction(self):
        return self.net[0].weight.data[0]
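

# Illustrative usage sketch (not part of the original file): CCS trains
# unsupervised on contrast pairs (activations of a statement and of its
# negation); labels are optional and only used to orient the learned
# direction. All tensors below are hypothetical stand-ins.
def _example_ccs_probe():
    acts = t.randn(256, 1024)                 # fake activations of statements
    neg_acts = t.randn(256, 1024)             # fake activations of their negations
    labels = t.randint(0, 2, (256,)).float()  # optional, only for orienting the probe
    probe = CCSProbe.from_data(acts, neg_acts, labels=labels, epochs=100)
    return probe.pred(acts)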