function [net,stats] = cnn_train_dag(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
%   CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with the DagNN
%   wrapper instead of the SimpleNN wrapper.

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.prefetch = false ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.momentum = 0.9 ;
opts.randomSeed = 0 ;
opts.stegoShuffle = false ;
opts.computeBNMoment = false ;
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ;
opts.profile = false ;
opts.cudnn = true ;

opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true ;
opts = vl_argparse(opts, varargin) ;

if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isnan(opts.train), opts.train = [] ; end

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

evaluateMode = isempty(opts.train) ;
if ~evaluateMode && isempty(opts.derOutputs)
  error('DEROUTPUTS must be specified when training.\n') ;
end

state.getBatch = getBatch ;
stats = [] ;

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)) ;
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;

start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, stats] = loadState(modelPath(start)) ;
end

if opts.computeBNMoment
  % Validate without using the moving averages of the BN moments
  for i = 1:numel(net.layers)
    if isa(net.layers(i).block, 'dagnn.BatchNorm')
      net.layers(i).block.computeMoment = true ;
    end
  end
end

for epoch = start+1:opts.numEpochs

  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.
  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
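  % Note (hedged): OPTS.LEARNINGRATE may be a scalar or a per-epoch vector;
  % the indexing below clamps the epoch index so that the last entry is
  % reused once the schedule runs out. For example, a caller could pass a
  % logarithmically decaying schedule (illustrative values):
  %
  %   cnn_train_dag(net, imdb, getBatch, 'learningRate', logspace(-2, -4, 60)) ;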
  state.epoch = epoch ;
  state.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;

  if opts.stegoShuffle == 1
    % Shuffle cover/stego pairs while keeping each cover image immediately
    % followed by its stego counterpart, distributing the pairs evenly
    % across the labs (GPUs).
    N = numel(opts.train) ;
    M = numel(opts.val) ;
    Lab = max(1, numel(opts.gpus)) ;
    % N and M must be even and multiples of Lab
    assert(rem(N, 2*Lab) == 0 && rem(M, 2*Lab) == 0) ;

    seq = opts.train(2*randperm(N/2) - 1) ;
    seq = reshape(seq, Lab, N/(2*Lab)) ;
    state.train = reshape([seq; seq+1], 1, N) ;

    seq = opts.val(2*randperm(M/2) - 1) ;
    seq = reshape(seq, Lab, M/(2*Lab)) ;
    state.val = reshape([seq; seq+1], 1, M) ;
  elseif opts.stegoShuffle < 0
    % Regression task: shuffle groups of K consecutive images as a unit.
    % (The original code hard-coded six offsets; this form works for any K.)
    K = abs(opts.stegoShuffle) ;
    M = numel(opts.train)/K ;
    seq = bsxfun(@plus, K * (randperm(M) - 1), (1:K)') ;
    state.train = opts.train(seq(:)) ;

    N = numel(opts.val)/K ;
    seq = bsxfun(@plus, K * (randperm(N) - 1), (1:K)') ;
    state.val = opts.val(seq(:)) ;
  else
    state.train = opts.train(randperm(numel(opts.train))) ; % shuffle
    state.val = opts.val(randperm(numel(opts.val))) ; % shuffle
  end
  state.imdb = imdb ;
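  % Worked example (hedged, illustrative values): with opts.train = 1:8,
  % a single lab (Lab = 1), and a pair permutation that picks the covers
  % [5 1 7 3], the interleaving above yields
  %
  %   state.train = [5 6 1 2 7 8 3 4]
  %
  % i.e. every cover index is immediately followed by its stego pair.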
  if numel(opts.gpus) <= 1
    [stats.train(epoch), prof] = process_epoch(net, state, opts, 'train') ;
    stats.val(epoch) = process_epoch(net, state, opts, 'val') ;
    if opts.profile
      profview(0, prof) ;
      keyboard ;
    end
  else
    savedNet = net.saveobj() ;
    spmd
      net_ = dagnn.DagNN.loadobj(savedNet) ;
      [stats_.train, prof_] = process_epoch(net_, state, opts, 'train') ;
      stats_.val = process_epoch(net_, state, opts, 'val') ;
      if labindex == 1, savedNet_ = net_.saveobj() ; end
    end
    net = dagnn.DagNN.loadobj(savedNet_{1}) ;
    stats__ = accumulateStats(stats_) ;
    stats.train(epoch) = stats__.train ;
    stats.val(epoch) = stats__.val ;
    if opts.profile
      mpiprofile('viewer', [prof_{:,1}]) ;
      keyboard ;
    end
    clear net_ stats_ stats__ savedNet savedNet_ ;
  end

  % save
  if ~evaluateMode
    saveState(modelPath(epoch), net, stats) ;
  end

  if opts.plotStatistics
    switchFigure(1) ; clf ;
    plots = setdiff(...
      cat(2, ...
      fieldnames(stats.train)', ...
      fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1, numel(plots), find(strcmp(p, plots))) ;
      plot(1:epoch, values', 'o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end
end

if opts.computeBNMoment
  % Revert it back
  for i = 1:numel(net.layers)
    if isa(net.layers(i).block, 'dagnn.BatchNorm')
      net.layers(i).block.computeMoment = false ;
    end
  end
end

% -------------------------------------------------------------------------
function [stats, prof] = process_epoch(net, state, opts, mode)
% -------------------------------------------------------------------------

% initialize empty momentum
if strcmp(mode, 'train')
  state.momentum = num2cell(zeros(1, numel(net.params))) ;
end

% move CNN to GPU as needed
numGpus = numel(opts.gpus) ;
if numGpus >= 1
  net.move('gpu') ;
  if strcmp(mode, 'train')
    state.momentum = cellfun(@gpuArray, state.momentum, 'UniformOutput', false) ;
  end
end
if numGpus > 1
  mmap = map_gradients(opts.memoryMapFile, net, numGpus) ;
else
  mmap = [] ;
end

% profile
if opts.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

subset = state.(mode) ;
num = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;

start = tic ;
for t = 1:opts.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, state.epoch, ...
          fix((t-1)/opts.batchSize)+1, ceil(numel(subset)/opts.batchSize)) ;
  batchSize = min(opts.batchSize, numel(subset) - t + 1) ;

  for s = 1:opts.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+opts.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    inputs = state.getBatch(state.imdb, batch) ;

    if opts.prefetch
      if s == opts.numSubBatches
        batchStart = t + (labindex-1) + opts.batchSize ;
        batchEnd = min(t+2*opts.batchSize-1, numel(subset)) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
      state.getBatch(state.imdb, nextBatch) ;
    end

    if strcmp(mode, 'train')
      net.mode = 'normal' ;
      net.accumulateParamDers = (s ~= 1) ;
      net.eval(inputs, opts.derOutputs) ;
    else
      net.mode = 'test' ;
      net.eval(inputs) ;
    end
  end

  % accumulate gradient
  if strcmp(mode, 'train')
    if ~isempty(mmap)
      write_gradients(mmap, net) ;
      labBarrier() ;
    end
    state = accumulate_gradients(state, net, opts, batchSize, mmap) ;
  end

  % get statistics
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats = opts.extractStatsFn(net) ;
  stats.num = num ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == opts.batchSize + 1
    % compensate for the first iteration, which is an outlier
    adjustTime = 2*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s:', f) ;
    fprintf(' %.3f', stats.(f)) ;
  end
  fprintf('\n') ;
end

if ~isempty(mmap)
  unmap_gradients(mmap) ;
end

if opts.profile
  if numGpus <= 1
    prof = profile('info') ;
    profile off ;
  else
    prof = mpiprofile('info') ;
    mpiprofile off ;
  end
else
  prof = [] ;
end

net.reset() ;
net.move('cpu') ;
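% Note (hedged): GETBATCH is expected to return its batch as the cell
% array of name/value pairs that DagNN eval() consumes. A minimal sketch,
% assuming the imdb layout used by the MatConvNet examples (the field
% names below are illustrative, not required by this file):
%
%   function inputs = getBatch(imdb, batch)
%   images = imdb.images.data(:,:,:,batch) ;
%   labels = imdb.images.label(1,batch) ;
%   inputs = {'input', images, 'label', labels} ;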
% -------------------------------------------------------------------------
function state = accumulate_gradients(state, net, opts, batchSize, mmap)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
otherGpus = setdiff(1:numGpus, labindex) ;

for p = 1:numel(net.params)

  % accumulate gradients from multiple labs (GPUs) if needed
  if numGpus > 1
    tag = net.params(p).name ;
    for g = otherGpus
      tmp = gpuArray(mmap.Data(g).(tag)) ;
      net.params(p).der = net.params(p).der + tmp ;
    end
  end

  switch net.params(p).trainMethod

    case 'average' % mainly for batch normalization
      thisLR = net.params(p).learningRate ;
      net.params(p).value = ...
          (1 - thisLR) * net.params(p).value + ...
          (thisLR/batchSize/net.params(p).fanout) * net.params(p).der ;

    case 'gradient'
      thisDecay = opts.weightDecay * net.params(p).weightDecay ;
      thisLR = state.learningRate * net.params(p).learningRate ;
      state.momentum{p} = opts.momentum * state.momentum{p} ...
        - thisDecay * net.params(p).value ...
        - (1 / batchSize) * net.params(p).der ;
      net.params(p).value = net.params(p).value + thisLR * state.momentum{p} ;

    otherwise % was "case 'otherwise'", which could never match
      error('Unknown training method ''%s'' for parameter ''%s''.', ...
            net.params(p).trainMethod, ...
            net.params(p).name) ;
  end
end
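% Note: the 'gradient' case above implements stochastic gradient descent
% with momentum and weight decay. In update-equation form, with learning
% rate eta, momentum coefficient mu, and decay lambda:
%
%   m <- mu * m - lambda * w - (1/batchSize) * dE/dw
%   w <- w + eta * m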
% -------------------------------------------------------------------------
function mmap = map_gradients(fname, net, numGpus)
% -------------------------------------------------------------------------
format = {} ;
for i = 1:numel(net.params)
  format(end+1,1:3) = {'single', size(net.params(i).value), net.params(i).name} ;
end
format(end+1,1:3) = {'double', [3 1], 'errors'} ;
if ~exist(fname, 'file') && (labindex == 1)
  f = fopen(fname, 'wb') ;
  for g = 1:numGpus
    for i = 1:size(format,1)
      fwrite(f, zeros(format{i,2}, format{i,1}), format{i,1}) ;
    end
  end
  fclose(f) ;
end
labBarrier() ;
mmap = memmapfile(fname, ...
                  'Format', format, ...
                  'Repeat', numGpus, ...
                  'Writable', true) ;

% -------------------------------------------------------------------------
function write_gradients(mmap, net)
% -------------------------------------------------------------------------
for i = 1:numel(net.params)
  mmap.Data(labindex).(net.params(i).name) = gather(net.params(i).der) ;
end

% -------------------------------------------------------------------------
function unmap_gradients(mmap)
% -------------------------------------------------------------------------
% intentionally left empty

% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------

for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize stats structure with same fields and same order as
  % stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;

      if g == numel(stats_)
        % normalize the weighted sum into a weighted average
        stats.(s).(f) = stats.(s).(f) / total ;
      end
    end
  end
  stats.(s).num = total ;
end

% -------------------------------------------------------------------------
function stats = extractStats(net)
% -------------------------------------------------------------------------
sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
stats = struct() ;
for i = 1:numel(sel)
  stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
end

% -------------------------------------------------------------------------
function saveState(fileName, net, stats)
% -------------------------------------------------------------------------
net_ = net ;
net = net_.saveobj() ;
save(fileName, 'net', 'stats') ;

% -------------------------------------------------------------------------
function [net, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;

% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;

% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
if get(0,'CurrentFigure') ~= n
  try
    set(0,'CurrentFigure',n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
if numGpus > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= numGpus
    delete(pool) ;
  end
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', numGpus) ;
    cold = true ;
  end
  if exist(opts.memoryMapFile, 'file')
    delete(opts.memoryMapFile) ;
  end
end
if numGpus >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename) ;
  if numGpus == 1
    gpuDevice(opts.gpus) ;
  else
    spmd, gpuDevice(opts.gpus(labindex)), end
  end
end
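% Example usage (hedged sketch; the network construction and the exact
% option values are illustrative, not prescribed by this file):
%
%   net = dagnn.DagNN() ;
%   % ... add layers, including a dagnn.Loss with output 'objective' ...
%   [net, stats] = cnn_train_dag(net, imdb, @getBatch, ...
%       'expDir', fullfile('data','exp'), ...
%       'batchSize', 64, 'numEpochs', 30, ...
%       'learningRate', 0.001, 'gpus', []) ;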