-
Notifications
You must be signed in to change notification settings - Fork 17
/
carafe.py
79 lines (63 loc) · 2.76 KB
/
carafe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torch import nn
from torch.nn import functional as F
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> optional in-place ReLU.

    The convolution is created with ``bias=False``; a conv bias would be
    cancelled by the batch-norm mean subtraction that immediately follows.
    """

    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation,
                 use_relu=True):
        """
        Args:
            c_in: Number of input channels.
            c_out: Number of output channels.
            kernel_size, stride, padding, dilation: Forwarded unchanged to
                ``nn.Conv2d``.
            use_relu: When False, no activation is applied after batch norm.
        """
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        # None when use_relu is False; forward then returns the BN output as-is.
        self.relu = nn.ReLU(inplace=True) if use_relu else None

    def forward(self, x):
        """Apply convolution, batch norm and (when enabled) ReLU to ``x``."""
        normed = self.bn(self.conv(x))
        if self.relu is None:
            return normed
        return self.relu(normed)
class CARAFE(nn.Module):
    """Content-Aware ReAssembly of FEatures (CARAFE) upsampler.

    An unofficial implementation of the module described in
    "https://arxiv.org/abs/1905.02188". A small encoder predicts one
    softmax-normalised ``k_up x k_up`` reassembly kernel per output pixel,
    and each output pixel is the kernel-weighted sum of the neighbourhood of
    its source pixel in the input feature map.
    """

    def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3):
        """Build the CARAFE module.

        Args:
            c: The channel number of the input and the output.
            c_mid: The channel number after compression.
            scale: The expected upsample scale (an integer, as required by
                ``nn.PixelShuffle``).
            k_up: The size of the reassembly kernel. Must be odd.
            k_enc: The kernel size of the encoder. Must be odd.

        Raises:
            ValueError: If ``k_up`` or ``k_enc`` is even. With the "same"
                padding used below (``k // 2``), an even kernel changes the
                spatial size, and the predicted kernels would no longer line
                up with the upsampled feature map.
        """
        super(CARAFE, self).__init__()
        if k_up % 2 == 0:
            raise ValueError(f"k_up must be odd, got {k_up}")
        if k_enc % 2 == 0:
            raise ValueError(f"k_enc must be odd, got {k_enc}")
        self.scale = scale
        # Channel compressor: c -> c_mid with a 1x1 conv.
        self.comp = ConvBNReLU(c, c_mid, kernel_size=1, stride=1,
                               padding=0, dilation=1)
        # Kernel encoder: predicts scale^2 * k_up^2 values per input pixel,
        # which pix_shf rearranges into k_up^2 values per output pixel.
        self.enc = ConvBNReLU(c_mid, (scale * k_up) ** 2, kernel_size=k_enc,
                              stride=1, padding=k_enc // 2, dilation=1,
                              use_relu=False)
        self.pix_shf = nn.PixelShuffle(scale)
        self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest')
        # dilation=scale on the nearest-upsampled map means the k_up x k_up
        # taps step one *original* pixel apart; padding (k_up // 2) * scale
        # keeps the output at h_ * w_ positions (odd k_up only).
        self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale,
                                padding=k_up // 2 * scale)

    def forward(self, X):
        """Upsample ``X`` by ``self.scale`` using content-aware reassembly.

        Args:
            X: Input feature map of shape (b, c, h, w).

        Returns:
            The upsampled feature map of shape (b, c, h * scale, w * scale).
        """
        b, c, h, w = X.size()
        h_, w_ = h * self.scale, w * self.scale

        # Kernel prediction: one normalised k_up^2-tap kernel per output pixel.
        W = self.comp(X)                 # b * c_mid * h * w
        W = self.enc(W)                  # b * (scale*k_up)^2 * h * w
        W = self.pix_shf(W)              # b * k_up^2 * h_ * w_
        W = F.softmax(W, dim=1)          # b * k_up^2 * h_ * w_

        # Reassembly. Output pixel (i, j) samples upsampled positions
        # (i + scale*di, j + scale*dj), i.e. input pixels
        # (i // scale + di, j // scale + dj): the k_up x k_up neighbourhood of
        # its source pixel. Out-of-range neighbours hit the zero padding.
        X = self.upsmp(X)                # b * c * h_ * w_
        X = self.unfold(X)               # b * (c*k_up^2) * (h_*w_), c-major
        X = X.view(b, c, -1, h_, w_)     # b * c * k_up^2 * h_ * w_
        X = torch.einsum('bkhw,bckhw->bchw', [W, X])  # b * c * h_ * w_
        return X
if __name__ == '__main__':
    # Smoke test. Use random data: the legacy torch.Tensor(*sizes) constructor
    # returns uninitialized memory, which may contain NaN/inf.
    x = torch.randn(1, 16, 24, 24)
    carafe = CARAFE(16)
    oup = carafe(x)
    print(oup.size())  # with the default scale=2: torch.Size([1, 16, 48, 48])