forked from sooyekim/3DSRnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.m
More file actions
50 lines (43 loc) · 1.46 KB
/
train.m
File metadata and controls
50 lines (43 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
function train(varargin)
% TRAIN Train the network returned by net(scale) with cnn_train_dag on the
% LR/HR pairs stored in ./data/train.
%
%   TRAIN() runs with the defaults below (scale 2).
%
%   TRAIN('Name', Value, ...) overrides defaults. Accepted names:
%     'scale'  - upscaling factor (default 2). Selects
%                ./data/train/LR_x<scale>.mat and ./data/train/HR_x<scale>.mat
%                and the default expDir ./net/net_x<scale>.
%     'numVal' - number of trailing samples given set label 2 (default 500);
%                presumably the validation split used by cnn_train_dag.
%     any field of the training options below, e.g. 'batchSize', 'gpus',
%     'learningRate', 'numEpochs', 'expDir', 'solver'.
%   Unrecognized names trigger a warning and are otherwise ignored.
%
%   The LR file must contain a variable LR and the HR file a variable HR,
%   each holding one sample per index along dimension 5.

% NOTE(review): presumably unloads the mex_conv3d MEX file so a fresh copy
% is loaded for this run -- confirm.
clear mex_conv3d;

% run configuration; parsed first because the data paths and the default
% expDir depend on the scale
cfg.scale = 2;
cfg.numVal = 500;
[cfg, varargin] = vl_argparse(cfg, varargin);
scale = cfg.scale;

% load data & label
data = load(sprintf('./data/train/LR_x%d.mat', scale));
label = load(sprintf('./data/train/HR_x%d.mat', scale));
numSamples = size(data.LR, 5);
assert(size(label.HR, 5) == numSamples, ...
    'train: LR has %d samples but HR has %d.', numSamples, size(label.HR, 5));
% without this check a too-large numVal makes ones(1, negative) return an
% empty row, leaving no training samples and a set vector longer than the data
assert(cfg.numVal >= 0 && cfg.numVal < numSamples, ...
    'train: numVal (%d) must be in [0, %d).', cfg.numVal, numSamples);
imdb.images.data = data.LR;
imdb.images.label = label.HR;
% leading samples get set 1, the last numVal samples get set 2
imdb.images.set = [ones(1, numSamples - cfg.numVal), 2 * ones(1, cfg.numVal)];

% set CNN model
network = net(scale);
% learning-rate multiplier 0.1 and no weight decay for params 2, 4, ..., 12
% (the biases); all other params (filters) keep their default values.
% NOTE(review): the range assumes net() has exactly 12 params -- confirm.
for i = 2:2:12
    network.params(i).learningRate = 0.1;
    network.params(i).weightDecay = 0;
end
network.conserveMemory = true;

% training options, passed to cnn_train_dag. The solver is set in the same
% struct that vl_argparse parses and cnn_train_dag receives; a field kept
% outside that struct would never reach the trainer.
% NOTE(review): requires a two-output function adam.m on the MATLAB path.
opts.solver = @adam;
opts.batchSize = 32;
opts.continue = false;
opts.gpus = 1;
opts.prefetch = false;
opts.expDir = sprintf('./net/net_x%d', scale);
opts.learningRate = [1e-4*ones(1,700) 1e-5*ones(1,100) 1e-6*ones(1,100)];
opts.weightDecay = 0.0005;
opts.numEpochs = numel(opts.learningRate);
opts.derOutputs = {'objective',1};
[opts, unused] = vl_argparse(opts, varargin);
% the second output of vl_argparse holds arguments it could not match;
% report them instead of dropping them silently (emptied structs are skipped)
unused = unused(~cellfun(@(a) isstruct(a) && isempty(fieldnames(a)), unused));
if ~isempty(unused)
    warning('train:unusedOptions', ...
        'train: ignoring %d unrecognized option argument(s).', numel(unused));
end

% experiment directory for cnn_train_dag
if ~exist(opts.expDir, 'dir')
    mkdir(opts.expDir);
end

% Call training function; opts is forwarded to getBatch's third argument
cnn_train_dag(network, imdb, @(db, batch) getBatch(db, batch, opts), opts);
function inputs = getBatch(imdb, batch, opts)
% GETBATCH Fetch one mini-batch of inputs and labels.
%
%   INPUTS = GETBATCH(IMDB, BATCH) selects the samples with indices BATCH
%   along dimension 5 of IMDB.images.data and IMDB.images.label, converts
%   them to single precision divided by 255, and moves them to the GPU.
%
%   INPUTS = GETBATCH(IMDB, BATCH, OPTS) does the same, except that the
%   arrays stay on the CPU when OPTS has a 'gpus' field that is empty.
%
%   INPUTS is the cell array {'input', X, 'label', Y}.
%
%   NOTE(review): the /255 scaling presumably maps 8-bit intensities to
%   [0, 1] -- confirm the stored data type.
images = single(imdb.images.data(:,:,:,:,batch)) / 255;
labels = single(imdb.images.label(:,:,:,:,batch)) / 255;

% Without usable opts, keep the previous behaviour and always use the GPU.
useGpu = nargin < 3 || ~isstruct(opts) || ~isfield(opts, 'gpus') || ~isempty(opts.gpus);
if useGpu
    images = gpuArray(images);
    labels = gpuArray(labels);
end
inputs = {'input', images, 'label', labels};