517 lines
15 KiB
Mathematica
517 lines
15 KiB
Mathematica
|
function [net,stats] = cnn_train_dag(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
%    CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with
%    the DagNN wrapper instead of the SimpleNN wrapper.
%
%    [NET, STATS] = CNN_TRAIN_DAG(NET, IMDB, GETBATCH, 'option', value, ...)
%    runs one training pass and one validation pass per epoch, saves a
%    checkpoint 'net-epoch-<N>.mat' in the experiment directory after each
%    epoch (unless running in evaluation mode), and optionally plots the
%    collected statistics to 'net-train.pdf'.
%
%    Inputs:
%      NET       network object; the code below uses NET.layers(i).block,
%                NET.saveobj() and dagnn.DagNN.loadobj().
%      IMDB      image database; IMDB.images.set is read to pick the
%                default training (set==1) and validation (set==2) indices.
%      GETBATCH  function handle called as GETBATCH(IMDB, BATCH); its
%                result is passed to NET.eval().
%
%    Options (default):
%      expDir (data/exp), continue (true), batchSize (256),
%      numSubBatches (1), train ([]), val ([]), gpus ([]),
%      prefetch (false), numEpochs (300), learningRate (0.001),
%      weightDecay (0.0005), momentum (0.9), randomSeed (0),
%      stegoShuffle (false), computeBNMoment (false),
%      memoryMapFile (<tempdir>/matconvnet.bin), profile (false),
%      cudnn (true), derOutputs ({'objective', 1}),
%      extractStatsFn (@extractStats), plotStatistics (true).
%
%    'train' set to NaN selects evaluation mode (no training subset and no
%    checkpoints written).
%
%    'stegoShuffle' selects how the subsets are shuffled at each epoch:
%      1      entries are drawn from the odd positions of the subset and
%             each drawn index X is paired with X+1; the subset sizes must
%             be multiples of 2*max(1, numel(gpus)).
%      -K     the subset is shuffled in blocks of K consecutive positions
%             (the subset sizes must be multiples of K).
%      other  plain random permutation.
%
%    Outputs:
%      NET    the network after the last epoch.
%      STATS  struct with fields 'train' and 'val', one element per epoch.

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.prefetch = false ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.momentum = 0.9 ;
opts.randomSeed = 0 ;
opts.stegoShuffle = false;
opts.computeBNMoment = false;
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ;
opts.profile = false ;
opts.cudnn = true ;

opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true;
opts = vl_argparse(opts, varargin) ;

if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isnan(opts.train), opts.train = [] ; end

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

evaluateMode = isempty(opts.train) ;
if ~evaluateMode
  if isempty(opts.derOutputs)
    % No trailing '\n': ERROR with a single argument prints it literally.
    error('DEROUTPUTS must be specified when training.') ;
  end
end

state.getBatch = getBatch ;
stats = [] ;

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;

% Resume from the most recent checkpoint unless 'continue' is disabled.
start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, stats] = loadState(modelPath(start)) ;
end

if ( opts.computeBNMoment == true )
  % Validation without using the moving average of BN moments
  for i = 1:numel(net.layers)
    if ( isa( net.layers(i).block, 'dagnn.BatchNorm') )
      net.layers(i).block.computeMoment = true;
    end
  end
end

for epoch=start+1:opts.numEpochs

  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.

  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
  state.epoch = epoch ;
  state.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;

  % shuffle
  if( opts.stegoShuffle == 1 )

    N = numel(opts.train);
    M = numel(opts.val);
    Lab = max(1, numel(opts.gpus));

    % M and N must be even, and multiple Lab
    assert( ( rem(N, 2*Lab) == 0 ) & ( rem(M, 2*Lab) == 0 ) );

    % Draw the entries stored at odd positions in random order and follow
    % every block of Lab drawn indices with the same indices plus one.
    % NOTE(review): presumably index X and X+1 form a cover/stego pair in
    % the imdb; confirm against the database layout.
    seq = opts.train(2*randperm(N/2) - 1);
    seq = reshape(seq, Lab, N/(2*Lab));
    state.train = reshape([seq; seq+1], 1, N);

    seq = opts.val(2*randperm(M/2) - 1);
    seq = reshape(seq, Lab, M/(2*Lab));
    state.val = reshape([seq; seq+1], 1, M);

  elseif ( opts.stegoShuffle < 0 )
    % for regression task: shuffle whole blocks of K consecutive entries
    K = abs( opts.stegoShuffle );
    state.train = shuffleKeepingGroups(opts.train, K) ;
    state.val = shuffleKeepingGroups(opts.val, K) ;

  else
    state.train = opts.train(randperm(numel(opts.train))) ;
    state.val = opts.val(randperm(numel(opts.val))) ;
  end
  state.imdb = imdb ;

  if numel(opts.gpus) <= 1
    [stats.train(epoch),prof] = process_epoch(net, state, opts, 'train') ;
    stats.val(epoch) = process_epoch(net, state, opts, 'val') ;
    if opts.profile
      profview(0,prof) ;
      keyboard ;
    end
  else
    % Each worker processes its share of every batch on its own copy of
    % the network; the copy held by worker 1 is brought back afterwards.
    savedNet = net.saveobj() ;
    spmd
      net_ = dagnn.DagNN.loadobj(savedNet) ;
      [stats_.train, prof_] = process_epoch(net_, state, opts, 'train') ;
      stats_.val = process_epoch(net_, state, opts, 'val') ;
      if labindex == 1, savedNet_ = net_.saveobj() ; end
    end
    net = dagnn.DagNN.loadobj(savedNet_{1}) ;
    stats__ = accumulateStats(stats_) ;
    stats.train(epoch) = stats__.train ;
    stats.val(epoch) = stats__.val ;
    if opts.profile
      mpiprofile('viewer', [prof_{:,1}]) ;
      keyboard ;
    end
    clear net_ stats_ stats__ savedNet savedNet_ ;
  end

  % save
  if ~evaluateMode
    saveState(modelPath(epoch), net, stats) ;
  end

  if opts.plotStatistics
    switchFigure(1) ; clf ;
    % One subplot per statistic other than the bookkeeping fields.
    plots = setdiff(...
      cat(2,...
      fieldnames(stats.train)', ...
      fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1,numel(plots),find(strcmp(p,plots))) ;
      plot(1:epoch, values','o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end
end

if ( opts.computeBNMoment == true )
  % Revert it back
  for i = 1:numel(net.layers)
    if ( isa( net.layers(i).block, 'dagnn.BatchNorm') )
      net.layers(i).block.computeMoment = false;
    end
  end
end

% -------------------------------------------------------------------------
function shuffled = shuffleKeepingGroups(indices, K)
% -------------------------------------------------------------------------
% Permute INDICES in blocks of K consecutive positions: the order of the
% blocks is random, the order of the entries inside each block is kept.
assert(rem(numel(indices), K) == 0, ...
  'The subset size (%d) must be a multiple of the group size (%d).', ...
  numel(indices), K) ;
numGroups = numel(indices) / K ;
offsets = K * ( randperm(numGroups) - 1 ) ;
seq = zeros(K, numGroups) ;
for k = 1:K
  seq(k,:) = offsets + k ;
end
seq = reshape(seq, numel(seq), 1) ;
shuffled = indices(seq) ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function [stats, prof] = process_epoch(net, state, opts, mode)
% -------------------------------------------------------------------------
%PROCESS_EPOCH  Run one pass over the subset selected by MODE.
%   MODE is 'train' or 'val' and selects STATE.(MODE) as the list of
%   sample indices. In 'train' mode the network is evaluated with
%   OPTS.derOutputs and the parameters are updated after every batch; in
%   any other mode the network is only evaluated in 'test' mode.
%
%   STATS holds the fields returned by OPTS.extractStatsFn plus 'num'
%   (samples processed by this worker) and 'time' (elapsed seconds).
%   PROF is the profiler output when OPTS.profile is set, [] otherwise.

isTraining = strcmp(mode,'train') ;
gpuCount = numel(opts.gpus) ;

% every call starts from a zero momentum, one entry per parameter
if isTraining
  state.momentum = num2cell(zeros(1, numel(net.params))) ;
end

% move the network (and the momentum) to the GPU when one is requested
if gpuCount >= 1
  net.move('gpu') ;
  if isTraining
    state.momentum = cellfun(@gpuArray,state.momentum,'UniformOutput',false) ;
  end
end

% with several GPUs the gradients are exchanged through a mapped file
mmap = [] ;
if gpuCount > 1
  mmap = map_gradients(opts.memoryMapFile, net, gpuCount) ;
end

% start the profiler matching the execution model
if opts.profile
  if gpuCount <= 1
    profile('clear') ;
    profile('on') ;
  else
    mpiprofile('reset') ;
    mpiprofile('on') ;
  end
end

indices = state.(mode) ;
total = numel(indices) ;
seen = 0 ;
stats.num = 0 ; % return something even if the subset is empty
stats.time = 0 ;
adjustTime = 0 ;

timer = tic ;
for t = 1:opts.batchSize:total
  fprintf('%s: epoch %02d: %3d/%3d:', mode, state.epoch, ...
          fix((t-1)/opts.batchSize)+1, ceil(total/opts.batchSize)) ;
  batchSize = min(opts.batchSize, total - t + 1) ;

  for s = 1:opts.numSubBatches
    % samples of a batch are interleaved across workers and sub-batches
    stride = opts.numSubBatches * numlabs ;
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+opts.batchSize-1, total) ;
    batch = indices(batchStart : stride : batchEnd) ;
    seen = seen + numel(batch) ;
    if numel(batch) == 0, continue ; end

    inputs = state.getBatch(state.imdb, batch) ;

    % optionally ask getBatch for the data that will be needed next
    if opts.prefetch
      if s == opts.numSubBatches
        batchStart = t + (labindex-1) + opts.batchSize ;
        batchEnd = min(t+2*opts.batchSize-1, total) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = indices(batchStart : stride : batchEnd) ;
      state.getBatch(state.imdb, nextBatch) ;
    end

    if isTraining
      net.mode = 'normal' ;
      % derivatives add up over the sub-batches of one batch
      net.accumulateParamDers = (s ~= 1) ;
      net.eval(inputs, opts.derOutputs) ;
    else
      net.mode = 'test' ;
      net.eval(inputs) ;
    end
  end

  % share the gradients between workers and update the parameters
  if isTraining
    if ~isempty(mmap)
      write_gradients(mmap, net) ;
      labBarrier() ;
    end
    state = accumulate_gradients(state, net, opts, batchSize, mmap) ;
  end

  % refresh the statistics; stats.time still holds the previous value here
  time = toc(timer) + adjustTime ;
  batchTime = time - stats.time ;
  stats = opts.extractStatsFn(net) ;
  stats.num = seen ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == opts.batchSize + 1
    % compensate for the first iteration, which is an outlier
    adjustTime = 2*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s:', f) ;
    fprintf(' %.3f', stats.(f)) ;
  end
  fprintf('\n') ;
end

if ~isempty(mmap)
  unmap_gradients(mmap) ;
end

% collect the profiler output, if any
prof = [] ;
if opts.profile
  if gpuCount <= 1
    prof = profile('info') ;
    profile('off') ;
  else
    prof = mpiprofile('info');
    mpiprofile('off') ;
  end
end

net.reset() ;
net.move('cpu') ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function state = accumulate_gradients(state, net, opts, batchSize, mmap)
% -------------------------------------------------------------------------
%ACCUMULATE_GRADIENTS  Apply one parameter update to NET.
%   For every parameter of NET the derivatives written by the other
%   workers are first added (only when more than one GPU is in use), then
%   the value is updated according to NET.params(p).trainMethod:
%
%     'average'   value <- (1-lr) * value + (lr/batchSize/fanout) * der,
%                 with lr = params(p).learningRate
%     'gradient'  momentum <- opts.momentum * momentum
%                             - decay * value - der / batchSize
%                 value    <- value + lr * momentum,
%                 with decay = opts.weightDecay * params(p).weightDecay
%                 and  lr = state.learningRate * params(p).learningRate
%
%   Any other train method raises an error.
%
%   STATE is returned with the updated STATE.momentum. MMAP is only read
%   when numel(OPTS.gpus) > 1.

numGpus = numel(opts.gpus) ;
if numGpus > 1
  otherGpus = setdiff(1:numGpus, labindex) ;
else
  otherGpus = [] ;
end

for p=1:numel(net.params)

  % accumulate gradients from multiple labs (GPUs) if needed
  if numGpus > 1
    tag = net.params(p).name ;
    for g = otherGpus
      tmp = gpuArray(mmap.Data(g).(tag)) ;
      net.params(p).der = net.params(p).der + tmp ;
    end
  end

  switch net.params(p).trainMethod

    case 'average' % mainly for batch normalization
      thisLR = net.params(p).learningRate ;
      net.params(p).value = ...
          (1 - thisLR) * net.params(p).value + ...
          (thisLR/batchSize/net.params(p).fanout) * net.params(p).der ;

    case 'gradient'
      thisDecay = opts.weightDecay * net.params(p).weightDecay ;
      thisLR = state.learningRate * net.params(p).learningRate ;
      state.momentum{p} = opts.momentum * state.momentum{p} ...
          - thisDecay * net.params(p).value ...
          - (1 / batchSize) * net.params(p).der ;
      net.params(p).value = net.params(p).value + thisLR * state.momentum{p} ;

    otherwise
      % Must be the OTHERWISE keyword: CASE 'otherwise' would only match
      % the literal string and let unknown methods pass silently.
      error('Unknown training method ''%s'' for parameter ''%s''.', ...
        net.params(p).trainMethod, ...
        net.params(p).name) ;
  end
end
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function mmap = map_gradients(fname, net, numGpus)
% -------------------------------------------------------------------------
%MAP_GRADIENTS  Map the file used to exchange gradients between workers.
%   The record layout has one 'single' field per network parameter (named
%   after the parameter and sized like its value) followed by a 3x1
%   'double' field called 'errors'. The record is repeated NUMGPUS times.
%   Worker 1 creates a zero-filled file when FNAME does not exist yet; all
%   workers wait at a barrier before mapping it as writable.

numParams = numel(net.params) ;
format = cell(numParams + 1, 3) ;
for k = 1:numParams
  format(k,:) = {'single', size(net.params(k).value), net.params(k).name} ;
end
format(end,:) = {'double', [3 1], 'errors'} ;

if ~exist(fname) && (labindex == 1)
  f = fopen(fname,'wb') ;
  for g = 1:numGpus
    for k = 1:size(format,1)
      [kind, dims] = format{k,1:2} ;
      fwrite(f, zeros(dims, kind), kind) ;
    end
  end
  fclose(f) ;
end

% nobody maps the file before worker 1 has finished writing it
labBarrier() ;

mmap = memmapfile(fname, ...
  'Format', format, ...
  'Repeat', numGpus, ...
  'Writable', true) ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function write_gradients(mmap, net)
% -------------------------------------------------------------------------
%WRITE_GRADIENTS  Publish this worker's derivatives through the mapped file.
%   Each parameter derivative is gathered and stored in the record of
%   MMAP.Data indexed by LABINDEX, under the parameter's name.
for k = 1:numel(net.params)
  paramName = net.params(k).name ;
  mmap.Data(labindex).(paramName) = gather(net.params(k).der) ;
end
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function unmap_gradients(mmap) %#ok<INUSD>
% -------------------------------------------------------------------------
%UNMAP_GRADIENTS  Placeholder called once the mapped file is not needed.
%   The function has no body: MMAP is accepted and nothing is executed.
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
%ACCUMULATESTATS  Merge the per-worker statistics into a single struct.
%   STATS_ is indexed as STATS_{g}; every element has fields 'train' and
%   'val', each a struct with a 'num' field plus any other statistics.
%   For both subsets, 'num' is summed over the workers and every other
%   field is averaged with weights equal to the workers' 'num'.
%
%   When no sample was processed at all (total 'num' is 0) the averaged
%   fields are left at 0 instead of becoming NaN through 0/0.

for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize the stats structure with the same fields and the same
  % order as stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  % weighted sums over the workers
  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
    end
  end

  % normalize, unless there is nothing to average
  if total > 0
    for f = setdiff(names, 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) / total ;
    end
  end
  stats.(s).num = total ;
end
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function stats = extractStats(net)
% -------------------------------------------------------------------------
%EXTRACTSTATS  Collect the running average of every loss layer.
%   Returns a struct with one field per layer whose block is a
%   'dagnn.Loss'; the field is named after the layer's first output and
%   holds the block's 'average' property. The struct has no fields when
%   the network contains no such layer.
stats = struct() ;
for k = 1:numel(net.layers)
  layer = net.layers(k) ;
  if isa(layer.block, 'dagnn.Loss')
    stats.(layer.outputs{1}) = layer.block.average ;
  end
end
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function saveState(fileName, net, stats)
% -------------------------------------------------------------------------
%SAVESTATE  Write a checkpoint to FILENAME.
%   The file receives two variables: 'net', the result of NET.saveobj(),
%   and 'stats', stored unchanged.
net = net.saveobj() ; % the variable name is the key used in the MAT-file
save(fileName, 'net', 'stats') ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function [net, stats] = loadState(fileName)
% -------------------------------------------------------------------------
%LOADSTATE  Read a checkpoint from FILENAME.
%   Loads the variables 'net' and 'stats' from the file and rebuilds the
%   network object from the saved 'net' with dagnn.DagNN.loadobj().
saved = load(fileName, 'net', 'stats') ;
stats = saved.stats ;
net = dagnn.DagNN.loadobj(saved.net) ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
%FINDLASTCHECKPOINT  Highest epoch number with a checkpoint in MODELDIR.
%   Looks for files named 'net-epoch-<digits>.mat' and returns the largest
%   <digits> value, or 0 when there is none. Files that match the
%   'net-epoch-*.mat' glob but do not carry a purely numeric epoch (for
%   example 'net-epoch-best.mat') are ignored.
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
% Anchored pattern with an escaped dot; 'once' yields {} for a non-match.
tokens = regexp({list.name}, '^net-epoch-(\d+)\.mat$', 'tokens', 'once') ;
tokens = tokens(~cellfun(@isempty, tokens)) ;
epoch = cellfun(@(x) sscanf(x{1}, '%d'), tokens) ;
epoch = max([epoch(:)' 0]) ;
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function switchFigure(n)
% -------------------------------------------------------------------------
%SWITCHFIGURE  Make figure N the current figure.
%   When the current figure differs from N, the root 'CurrentFigure'
%   property is set to N; if that fails, FIGURE(N) is called instead.
%   Nothing is done when the comparison below is not true, which includes
%   the case where GET returns an empty value.
current = get(0,'CurrentFigure') ;
if current ~= n
  try
    set(0,'CurrentFigure',n) ;
  catch
    figure(n) ;
  end
end
|
||
|
|
||
|
% -------------------------------------------------------------------------
|
||
|
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
%PREPAREGPUS  Set up the parallel pool and the GPU devices.
%   With more than one GPU: the current pool is deleted when its number
%   of workers differs from the number of GPUs, a 'local' pool with one
%   worker per GPU is opened when none exists (which forces a cold
%   start), and the file OPTS.memoryMapFile is deleted when present.
%   On a cold start with at least one GPU, GPUDEVICE is called for the
%   requested device(s), on each worker when several GPUs are used.
gpuCount = numel(opts.gpus) ;

if gpuCount > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= gpuCount
    delete(pool) ;
  end

  % query again: the pool may just have been deleted
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', gpuCount) ;
    cold = true ;
  end

  if exist(opts.memoryMapFile)
    delete(opts.memoryMapFile) ;
  end
end

if gpuCount >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename)
  if gpuCount == 1
    gpuDevice(opts.gpus)
  else
    spmd
      gpuDevice(opts.gpus(labindex))
    end
  end
end
|
||
|
|
||
|
%end
|