-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfigs.lua
More file actions
206 lines (159 loc) · 6.46 KB
/
configs.lua
File metadata and controls
206 lines (159 loc) · 6.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
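--[[
Sets up the configuration for the train script: parses the options, builds the
optimizer-state schedule and the criterion, and prepares the model (either a
fresh one or one restored from a saved snapshot).
]]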
--- Moves a model to the GPU or CPU backend and records the tensor type in `opt`.
--
-- GPU path (taken when `opt.GPU >= 1` and `is_gpu` is truthy): requires cutorch
-- and cunn, sets `opt.data_type = 'torch.CudaTensor'`, calls `model:cuda()` and,
-- when cudnn can be required and `opt.disable_cudnn` is not set, converts the
-- model to cudnn modules.
-- CPU path (everything else): sets `opt.data_type = 'torch.FloatTensor'`,
-- converts modules back to `nn` when cudnn can be required, and calls
-- `model:float()`.
--
-- @param model   network object exposing :cuda() / :float(); converted in place
-- @param opt     options table; reads GPU, nGPU, disable_cudnn and
--                cudnn_deterministic, writes data_type
-- @param is_gpu  boolean selecting the GPU path (must not be nil)
-- @return the `model` that was passed in
local function convert_model_backend(model, opt, is_gpu)
    assert(model, 'convert_model_backend: missing model')
    assert(opt, 'convert_model_backend: missing opt')
    assert(is_gpu ~= nil, 'convert_model_backend: is_gpu must be true or false')

    if opt.GPU >= 1 and is_gpu then
        -- The branch is gated on opt.GPU, not opt.nGPU, so nGPU may be unset here;
        -- tostring() keeps this log line from raising on a nil concatenation.
        print('Running on GPU: num_gpus = [' .. tostring(opt.nGPU) .. ']')
        require 'cutorch'
        require 'cunn'
        opt.data_type = 'torch.CudaTensor'
        model:cuda()

        -- cudnn is optional: it is used only if it loads and was not disabled.
        -- The require is attempted first, so cudnn is loaded even when
        -- opt.disable_cudnn is set.
        if pcall(require, 'cudnn') and not opt.disable_cudnn then
            cudnn.convert(model, cudnn):cuda()
            cudnn.benchmark = true
            if opt.cudnn_deterministic then
                -- NOTE(review): setMode(1,1,1) presumably pins the cudnn algorithm
                -- choice for reproducibility - confirm against the cudnn bindings.
                model:apply(function(m) if m.setMode then m:setMode(1,1,1) end end)
            end
            print('Network has', #model:findModules'cudnn.SpatialConvolution', 'cudnn convolutions')
        end
    else
        print('Running on CPU')
        opt.data_type = 'torch.FloatTensor'
        -- If cudnn is installed, swap any cudnn modules back to plain nn before
        -- casting to float.
        if pcall(require, 'cudnn') then
            cudnn.convert(model, nn)
        end
        model:float()
    end

    return model
end
------------------------------------------------------------------------------------------------------------
--- Prepares everything the training script needs from the user options.
--
-- Steps, in order: parse options, create the save directory, select the GPU,
-- build the optimizer-state function (optionally driven by `opt.schedule`),
-- seed the RNG, build the criterion, then either restore a snapshot (when
-- `opt.continue` is set and a snapshot config exists) or compute the bbox
-- regression mean/std, append a normalisation layer to `model`, convert the
-- backend and save the options to disk.
--
-- @param model            network container; a ParallelTable with a BBoxNorm
--                         layer is appended to it in place (fresh-start path)
-- @param dataLoadTable    table whose `train` field is given to the ROI sampler
-- @param rois             table whose `train` field is given to the ROI sampler
-- @param modelParameters  passed to the ROI sampler and saved as
--                         'model_parameters.t7'
-- @param opts             optional table handed to Options:parse (defaults to {})
-- @return opt           parsed options (with data_type, bbox_meanstd, ... added)
-- @return modelOut      model converted to the selected backend
-- @return criterion     ParallelCriterion cast to `opt.data_type`
-- @return optimStateFn  function(epoch) returning the optimizer state table
-- @return nEpochs       last epoch covered by `opt.schedule`; nil when the
--                       schedule is not a table or is empty
local function LoadConfigs(model, dataLoadTable, rois, modelParameters, opts)

    torch.setdefaulttensortype('torch.FloatTensor')

    -------------------------------------------------------------------------------
    -- Process command line options
    -------------------------------------------------------------------------------

    local Options = fastrcnn.Options()
    local opt = Options:parse(opts or {})

    print('Saving everything to: ' .. opt.savedir)
    -- NOTE(review): savedir is spliced into a shell command unquoted, so paths
    -- containing spaces or shell metacharacters will misbehave. Left as-is because
    -- quoting would also stop the shell from expanding things like '~'.
    os.execute('mkdir -p ' .. opt.savedir)

    if opt.GPU >= 1 then
        require 'cutorch'
        require 'cunn'
        cutorch.setDevice(opt.GPU)
    end

    -- Training hyperparameters
    -- (Some of these aren't relevant for rmsprop which is the optimization we use)
    local optimState = {
        learningRate = opt.LR,
        learningRateDecay = opt.LRdecay,
        momentum = opt.momentum,
        dampening = 0.0,
        weightDecay = opt.weightDecay
    }

    -- define optim state function ()
    local optimStateFn, nEpochs
    if type(opt.schedule) == 'table' then
        -- Expand each {num_epochs, learning_rate, weight_decay} entry into
        -- {first_epoch, last_epoch, learning_rate, weight_decay}, with the epoch
        -- ranges laid out back to back starting at epoch 1.
        local schedule = {}
        local last_epoch = 0
        for i = 1, #opt.schedule do
            local entry = opt.schedule[i]
            schedule[#schedule + 1] = {last_epoch + 1, last_epoch + entry[1], entry[2], entry[3]}
            last_epoch = last_epoch + entry[1]
        end

        optimStateFn = function(epoch)
            for _, v in ipairs(schedule) do
                if v[1] <= epoch and v[2] >= epoch then
                    return {
                        learningRate = v[3],
                        learningRateDecay = opt.LRdecay,
                        momentum = opt.momentum,
                        dampening = 0.0,
                        weightDecay = v[4],
                        -- 1 on the final epoch of this schedule entry, 0 otherwise
                        end_schedule = (v[2] == epoch and 1) or 0
                    }
                end
            end
            -- epoch lies outside every scheduled range: use the default state
            return optimState
        end

        -- The maximum number of epochs is the end of the last schedule entry.
        -- Read it directly instead of relying on pairs() visiting entries in order.
        if #schedule > 0 then
            nEpochs = schedule[#schedule][2]
        end
    else
        optimStateFn = function(epoch) return optimState end
    end

    -- Random number seed
    if opt.manualSeed ~= -1 then
        torch.manualSeed(opt.manualSeed)
    else
        torch.seed()
    end

    -------------------------------------------------------------------------------
    -- Setup criterion
    -------------------------------------------------------------------------------

    local criterion = nn.ParallelCriterion()
        :add(nn.CrossEntropyCriterion(), 1)
        :add(nn.BBoxRegressionCriterion(), 1)

    -------------------------------------------------------------------------------
    -- Continue from snapshot
    -------------------------------------------------------------------------------

    opt.curr_save_configs = paths.concat(opt.savedir, 'curr_save_configs.t7')
    if opt.continue then
        if paths.filep(opt.curr_save_configs) then
            -- load snapshot configs
            local confs = torch.load(opt.curr_save_configs)
            opt.bbox_meanstd = confs.bbox_meanstd
            opt.epochStart = confs.epoch + 1

            -- load model from disk
            local snapshot_path = paths.concat(opt.savedir, confs.model_name)
            print('Loading model: ' .. snapshot_path)
            local restored = torch.load(snapshot_path)[1]
            restored = convert_model_backend(restored, opt, true)
            criterion:type(opt.data_type)

            return opt, restored, criterion, optimStateFn, nEpochs
        end
        -- Make the fall-through visible instead of silently starting over.
        print('No snapshot found at ' .. opt.curr_save_configs .. '; starting from scratch.')
    end

    -------------------------------------------------------------------------------
    -- Preprocess rois
    -------------------------------------------------------------------------------

    do
        local nSamples = 1000
        print('Compute bbox regression mean/std values over '..nSamples..' train images...')

        local tic = torch.tic()
        local batchprovider = fastrcnn.BatchROISampler(dataLoadTable.train, rois.train, modelParameters, opt, 'train')

        -- compute regression mean/std
        opt.bbox_meanstd = batchprovider:setupData(nSamples)

        print('Done. Elapsed time: ' .. torch.toc(tic))
        print('mean: ', opt.bbox_meanstd.mean)
        print('std: ', opt.bbox_meanstd.std)
    end

    -------------------------------------------------------------------------------
    -- Setup model
    -------------------------------------------------------------------------------

    local modelOut = nn.Sequential()

    -- add mean/std norm (this mutates the caller's `model`)
    model:add(nn.ParallelTable()
        :add(nn.Identity())
        :add(nn.BBoxNorm(opt.bbox_meanstd.mean, opt.bbox_meanstd.std)))
    modelOut:add(model)

    -- convert model backend/type
    modelOut = convert_model_backend(modelOut, opt, true)
    criterion:type(opt.data_type)

    if opt.verbose then
        print('Network:')
        print(model)
    end

    -- Save options to experiment directory
    torch.save(paths.concat(opt.savedir, 'options.t7'), opt)
    torch.save(paths.concat(opt.savedir, 'model_parameters.t7'), modelParameters)

    collectgarbage()
    collectgarbage()

    return opt, modelOut, criterion, optimStateFn, nEpochs
end
---------------------------------------------------------------------------------------------------------------------
-- Chunk result: loading this file yields the LoadConfigs function itself.
return LoadConfigs