-
Notifications
You must be signed in to change notification settings - Fork 11
/
BCNN.py
32 lines (26 loc) · 859 Bytes
/
BCNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
class BCNN(nn.Module):
    """Bilinear pooling head (BCNN-style second-order pooling).

    For an input feature map of shape ``(batch, channels, h, w)`` this module
    computes, per sample, the channel Gram matrix averaged over the ``h * w``
    spatial positions, applies an element-wise signed square root, and then
    L2-normalizes each sample's whole representation.

    Args:
        thresh: small positive constant added inside the square root; it keeps
            the square-root gradient bounded for entries close to zero.
        is_vec: if True, the output is flattened to
            ``(batch, channels * channels)``; otherwise it stays
            ``(batch, channels, channels)``.
        input_dim: number of input channels. Only used to compute
            ``output_dim``; it is not checked against the actual input.

    Attributes:
        output_dim: ``input_dim * input_dim``, the length of the flattened
            representation.
    """

    def __init__(self, thresh=1e-8, is_vec=True, input_dim=512):
        super().__init__()
        self.thresh = thresh
        self.is_vec = is_vec
        self.output_dim = input_dim * input_dim

    def _bilinearpool(self, x):
        """Return the spatially averaged Gram matrix ``X @ X^T / (h * w)``.

        ``x`` of shape ``(batch, dim, h, w)`` is reshaped to
        ``(batch, dim, h * w)`` and multiplied by its own transpose, giving
        ``(batch, dim, dim)``.

        Raises:
            ValueError: if ``x`` is not a 4-D tensor.
        """
        if x.dim() != 4:
            raise ValueError(
                "BCNN expects a 4-D (batch, channels, h, w) input, "
                f"got shape {tuple(x.shape)}"
            )
        # Read the shape directly; going through `.data` is unnecessary.
        batch_size, dim, h, w = x.shape
        x = x.reshape(batch_size, dim, h * w)
        return 1.0 / (h * w) * x.bmm(x.transpose(1, 2))

    def _signed_sqrt(self, x):
        """Element-wise ``sign(x) * sqrt(|x| + thresh)``."""
        return torch.mul(x.sign(), torch.sqrt(x.abs() + self.thresh))

    def _l2norm(self, x):
        """L2-normalize each sample's entire representation, keeping x's shape.

        The per-sample tensor is flattened before normalizing. On a
        ``(batch, C, C)`` tensor, ``normalize``'s default ``dim=1`` would
        instead normalize every column separately. Flattening makes the
        matrix output (``is_vec=False``) an exact reshape of the vector output.
        """
        flat = x.reshape(x.size(0), -1)
        return nn.functional.normalize(flat, dim=1).reshape(x.shape)

    def forward(self, x):
        """Apply bilinear pooling, signed square root and L2 normalization.

        Args:
            x: 4-D feature tensor ``(batch, channels, h, w)``.

        Returns:
            ``(batch, channels * channels)`` if ``is_vec`` is True, else
            ``(batch, channels, channels)``. Each sample has unit L2 norm
            (unless it is all zeros).
        """
        x = self._bilinearpool(x)
        x = self._signed_sqrt(x)
        if self.is_vec:
            x = x.view(x.size(0), -1)
        return self._l2norm(x)