FireModule.lua — forked from Element-Research/dpnn (47 lines, 39 loc, 1.51 KB).
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
--[[
Fire module as explained in SqueezeNet http://arxiv.org/pdf/1602.07360v1.pdf.
--]]
--FIXME works only for batches.
local FireModule, Parent = torch.class('nn.FireModule', 'nn.Decorator')

--- Fire module as described in SqueezeNet (http://arxiv.org/pdf/1602.07360v1.pdf):
-- a 1x1 "squeeze" convolution, followed by an "expand" stage of parallel
-- 1x1 and 3x3 convolutions whose outputs are concatenated along the
-- channel dimension. An activation is applied after each stage.
-- @param nInputPlane number of input feature planes
-- @param s1x1 number of 1x1 squeeze filters
-- @param e1x1 number of 1x1 expand filters
-- @param e3x3 number of 3x3 expand filters
-- @param activation name of an nn activation module (default 'ReLU')
function FireModule:__init(nInputPlane, s1x1, e1x1, e3x3, activation)
   self.nInputPlane = nInputPlane
   self.s1x1 = s1x1
   self.e1x1 = e1x1
   self.e3x3 = e3x3
   self.activation = activation or 'ReLU'
   -- Fail early with a clear message instead of "attempt to call nil"
   -- when the activation name does not resolve to an nn module.
   assert(nn[self.activation], 'Unknown activation: '..tostring(self.activation))
   -- SqueezeNet recommends s1x1 < e1x1 + e3x3 so the squeeze layer
   -- actually reduces dimensionality before the expand stage.
   if self.s1x1 > (self.e1x1 + self.e3x3) then
      print('Warning: <FireModule> s1x1 is recommended to be smaller'..
            ' than e1x1+e3x3')
   end
   self.module = nn.Sequential()
   self.squeeze = nn.SpatialConvolution(nInputPlane, s1x1, 1, 1)
   -- Concatenate along dimension 2, the channel dimension of a 4D batch
   -- tensor -- this is why the module only supports batches (see FIXME).
   self.expand = nn.Concat(2)
   self.expand:add(nn.SpatialConvolution(s1x1, e1x1, 1, 1))
   -- 3x3 branch uses stride 1 and padding 1 so both branches keep the
   -- same spatial size and can be concatenated.
   self.expand:add(nn.SpatialConvolution(s1x1, e3x3, 3, 3, 1, 1, 1, 1))
   -- Fire Module: squeeze -> activation -> expand -> activation
   self.module:add(self.squeeze)
   self.module:add(nn[self.activation]())
   self.module:add(self.expand)
   self.module:add(nn[self.activation]())
   Parent.__init(self, self.module)
end
--[[
function FireModule:type(type, tensorCache)
assert(type, 'Module: must provide a type to convert to')
self.module = nn.utils.recursiveType(self.module, type, tensorCache)
end
--]]
--- Returns a one-line human-readable summary of the module configuration.
function FireModule:__tostring__()
   local template = '%s inputPlanes: %d -> Squeeze Planes: %d -> '..
      'Expand: %d(1x1) + %d(3x3), activation: %s'
   return template:format(torch.type(self), self.nInputPlane,
                          self.s1x1, self.e1x1, self.e3x3, self.activation)
end