----------------------------------------------------------------------
-- Training routine
--
-- Rana Hanocka
----------------------------------------------------------------------
require 'torch' -- torch
require 'optim' -- an optimization package, for online and batch methods
----------------------------------------------------------------------
--[[
   1. Setup the Adam optimization state and learning rate schedule.
   2. Create loggers.
   3. train - this function handles the high-level training loop,
              i.e. load data, train model, save model and state to disk.
   4. trainBatch - Used by train() to train a single batch after the data is loaded.
]]--
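
-- Note: this script assumes several globals are created elsewhere in the repository before
-- it is loaded: opt (parsed options), model, criterion, epoch, donkeys (the data-worker
-- thread pool), trainLoader, and saveDataParallel (defined in util.lua). Nothing in this
-- file creates them.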
-- 1. Setup a reused optimization state (for Adam). If needed, reload it from disk.
print(sys.COLORS.red .. '==> configuring optimizer')
local optimState = {
   learningRate = opt.LR,
   learningRateDecay = opt.learningRateDecay,
   momentum = opt.momentum,  -- not used by Adam
   weightDecay = opt.weightDecay,
   beta1 = opt.beta1,
   beta2 = opt.beta2,
   epsilon = opt.epsilon
}
if opt.optimState ~= 'none' then
   assert(paths.filep(opt.optimState), 'File not found: ' .. opt.optimState)
   print('Loading optimState from file: ' .. opt.optimState)
   optimState = torch.load(opt.optimState)
end
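
-- The table above is passed to optim.adam() in trainBatch(). As a rough reference (not a
-- statement about this repo's option defaults), torch/optim falls back to
-- learningRate = 1e-3, beta1 = 0.9, beta2 = 0.999 and epsilon = 1e-8 when these fields are
-- absent, and optim.adam ignores the momentum field. The optimizer also stores its running
-- moment estimates in this same table, which is why saving and reloading optimState
-- resumes training with the optimizer state intact.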
-- 2. Create loggers.
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
--trainLogger:display(false)
local batchNumber
local loss_epoch
-- 3. train - this function handles the high-level training loop,
--            i.e. load data, train model, save model and state to disk
function train()
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch)

   batchNumber = 0
   cutorch.synchronize()

   -- set the dropouts to training mode
   model:training()

   local tm = torch.Timer()
   loss_epoch = 0
   for i = 1, opt.epochSize do
      -- queue jobs to data-workers
      donkeys:addjob(
         -- the job callback (runs in data-worker thread)
         function()
            local inputs, fieldI, sourceImgs, targetImgs = trainLoader:sample(opt.batchSize)
            return inputs, fieldI, sourceImgs, targetImgs
         end,
         -- the end callback (runs in the main thread)
         trainBatch
      )
   end

   donkeys:synchronize()
   cutorch.synchronize()

   loss_epoch = loss_epoch / (opt.epochSize * opt.batchSize)

   trainLogger:add{
      ['avg loss (train set)'] = loss_epoch
   }
   print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
                          .. 'average loss (per batch): %.2f \t ',
                       epoch, tm:time().real, loss_epoch))
   print('\n')
   collectgarbage()

   -- save model
   -- clear the intermediate states in the model before saving to disk;
   -- this saves lots of disk space
   model:clearState()
   saveDataParallel(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model) -- defined in util.lua
   torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
end -- of train()
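
--[[ train() queues opt.epochSize jobs on 'donkeys': each job samples a batch in a worker
     thread, and its return values are handed to trainBatch() on the main thread;
     donkeys:synchronize() then blocks until every queued job has finished. A minimal
     sketch of that pattern with the torch 'threads' package (the real pool and
     trainLoader are built in the repository's data-loading code, so the names below are
     illustrative only):

   local threads = require 'threads'
   local pool = threads.Threads(4, function() require 'torch' end)  -- init runs in each worker
   pool:addjob(function() return torch.rand(2) end,                 -- job: runs in a worker thread
               function(batch) print(batch:size()) end)             -- end callback: runs in the main thread
   pool:synchronize()                                               -- wait for all queued jobs
]]--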
-------------------------------------------------------------------------------------------
-- GPU inputs (preallocate)
local inputs = torch.CudaTensor()
local fieldI = torch.CudaTensor()
local sourceImgs = torch.CudaTensor()
local targetImgs = torch.CudaTensor()
local timer = torch.Timer()
local dataTimer = torch.Timer()
local parameters, gradParameters = model:getParameters()
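
-- model:getParameters() flattens all learnable weights and their gradients into two single
-- tensors; optim.adam updates 'parameters' in place, so trainBatch() needs no explicit
-- model:updateParameters() call.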
-- 4. trainBatch - Used by train() to train a single batch after the data is loaded.
function trainBatch(inputsCPU, fieldICPU, sourceImgsCPU, targetImgsCPU)
   cutorch.synchronize()
   collectgarbage()
   local dataLoadingTime = dataTimer:time().real
   timer:reset()

   -- transfer over to GPU
   inputs:resize(inputsCPU:size()):copy(inputsCPU)
   fieldI:resize(fieldICPU:size()):copy(fieldICPU)
   sourceImgs:resize(sourceImgsCPU:size()):copy(sourceImgsCPU)
   targetImgs:resize(targetImgsCPU:size()):copy(targetImgsCPU)

   if opt.swap then sourceImgs, targetImgs = targetImgs, sourceImgs end -- needed for 3D (compute forward warping)
   local err, outputs
   feval = function(x)
      model:zeroGradParameters()
      outputs = model:forward({sourceImgs, {inputs, fieldI}})
      err = criterion:forward(outputs, targetImgs)
      local gradOutputs = criterion:backward(outputs, targetImgs)
      model:backward({sourceImgs, {inputs, fieldI}}, gradOutputs)
      return err, gradParameters
   end
   optim.adam(feval, parameters, optimState)
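
   -- feval follows the standard torch/optim closure contract: given the current flat
   -- parameter tensor it returns the loss and the flat gradient tensor; optim.adam calls
   -- it once per step and then applies the Adam update to 'parameters' using optimState.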
   cutorch.synchronize()
   batchNumber = batchNumber + 1
   loss_epoch = loss_epoch + err

   -- Print information
   print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format(
          epoch, batchNumber, opt.epochSize, timer:time().real, err,
          optimState.learningRate, dataLoadingTime))

   dataTimer:reset()
end
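
--[[ A minimal sketch of how this file is typically driven; the actual loop presumably
     lives in the repository's driver script, and 'nEpochs' is an assumed option name used
     here only for illustration:

   epoch = 1
   for i = 1, opt.nEpochs do
      train()
      epoch = epoch + 1
   end
]]--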