forked from jcjohnson/neural-style
-
Notifications
You must be signed in to change notification settings - Fork 1
/
loadcaffe_wrapper.lua
134 lines (126 loc) · 4.08 KB
/
loadcaffe_wrapper.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
local ffi = require 'ffi'
require 'loadcaffe'
local C = loadcaffe.C
--[[
Most of this function is copied from
https://github.com/szagoruyko/loadcaffe/blob/master/loadcaffe.lua
with some horrible horrible hacks added by Justin Johnson to
make it possible to load VGG-19 without any CUDA dependency.
--]]

--- Copy the text file `src_name` to `dst_name` one line at a time.
-- Each line (without its trailing newline) is passed to
-- `transform(line, line_num)`, with `line_num` starting at 1. A string
-- result is written followed by '\n'; a nil result drops the line.
local function rewrite_lua_file(src_name, dst_name, transform)
  local fin = assert(io.open(src_name, 'r'))
  local fout = assert(io.open(dst_name, 'w'))
  local line_num = 1
  for line in fin:lines() do
    local out = transform(line, line_num)
    if out ~= nil then
      fout:write(out, '\n')
    end
    line_num = line_num + 1
  end
  fin:close()
  fout:close()
end

--- Line transform used for backend 'nn-cpu'.
-- Drops lines 2 and 4 of the generated file (the cunn and inn requires).
-- Every other line is kept once, with "inn" renamed to "nn".
local function cpu_line(line, line_num)
  if line_num == 2 or line_num == 4 then
    return nil
  end
  -- Fix for using nin_imagenet_conv.caffemodel
  return (line:gsub("inn", "nn"))
end

--- Line transform used for backend 'clnn'.
-- Renames "inn" to "nn" (fix for nin_imagenet_conv.caffemodel) and appends
-- `:ceil()` to SpatialAveragePooling constructors, printing each changed
-- line. Line 1 is replaced by `require 'clnn'` to swap the CUDA libraries
-- for OpenCL ones (the original author's ATI FirePro V3900 cannot run
-- CUDA). Lines 2 and 4 are dropped.
local function opencl_line(line, line_num)
  if line:find("inn") then
    print("Changing line: ", line)
    line = line:gsub("inn", "nn")
    print("To line: ", line)
  end
  -- (A disabled variant of this hack also blanked out SoftMax lines.)
  if line:find("SpatialAveragePooling") then
    print("Changing line: ", line)
    -- "...)})" becomes "...):ceil()})", i.e. :ceil() is called on the
    -- freshly constructed pooling module.
    line = line:gsub("%}%)", ":ceil()})")
    print("To line: ", line)
  end
  if line_num == 1 then
    return "require 'clnn'"
  elseif line_num == 2 or line_num == 4 then
    return nil
  end
  return line
end

--- Load a Caffe model into a torch `nn.Sequential`.
-- @param prototxt_name path of the .prototxt network definition. The
--   generated Lua description is written to `<prototxt_name>.lua`, and for
--   'nn-cpu' / 'clnn' a rewritten copy to `.cpu.lua` / `.opencl.lua`.
-- @param binary_name path of the .caffemodel weights file.
-- @param backend 'nn' (default), 'nn-cpu', or 'clnn'. Any other value
--   (e.g. 'cudnn', 'ccn2') is passed straight to C.convertProtoToLua.
-- @return the assembled network. Returns nil when C.loadBinary leaves the
--   handle unchanged, which is treated as a load failure.
local function loadcaffe_load(prototxt_name, binary_name, backend)
  backend = backend or 'nn'
  -- loads caffe model in memory and keeps handle to it in ffi
  local handle = ffi.new('void*[1]')
  local old_val = handle[1]
  C.loadBinary(handle, prototxt_name, binary_name)
  if old_val == handle[1] then return end

  -- C.convertProtoToLua writes a .lua source file that builds a table of
  -- {layer_name, layer_module} pairs. As a dirty hack, for 'nn-cpu' and
  -- 'clnn' we generate the plain 'nn' version and rewrite it line by line.
  -- Lines 2 and 4 of that file import cunn and inn and are always removed.
  local lua_name = prototxt_name .. '.lua'
  local model
  if backend == 'nn-cpu' then
    C.convertProtoToLua(handle, lua_name, 'nn')
    local lua_name_cpu = prototxt_name .. '.cpu.lua'
    rewrite_lua_file(lua_name, lua_name_cpu, cpu_line)
    model = dofile(lua_name_cpu)
  elseif backend == 'clnn' then
    C.convertProtoToLua(handle, lua_name, 'nn')
    local lua_name_opencl = prototxt_name .. '.opencl.lua'
    rewrite_lua_file(lua_name, lua_name_opencl, opencl_line)
    model = dofile(lua_name_opencl)
  else
    C.convertProtoToLua(handle, lua_name, backend)
    model = dofile(lua_name)
  end

  -- goes over the list, copying weights from caffe blobs to torch tensors
  local net = nn.Sequential()
  for _, item in ipairs(model) do
    local layer_name, layer = item[1], item[2]
    layer.name = layer_name
    if layer.weight then
      local w = torch.FloatTensor()
      local bias = torch.FloatTensor()
      C.loadModule(handle, layer_name, w:cdata(), bias:cdata())
      if backend == 'ccn2' then
        w = w:permute(2, 3, 4, 1)
      end
      layer.weight:copy(w)
      layer.bias:copy(bias)
    end
    net:add(layer)
  end
  C.destroyBinary(handle)

  if backend == 'cudnn' or backend == 'ccn2' then
    net:cuda()
  elseif backend == 'clnn' then
    net:cl()
  end
  return net
end
-- Public interface of this module: a table whose `load` field is
-- loadcaffe_load(prototxt_name, binary_name, backend).
local M = { load = loadcaffe_load }
return M