Steganalysis/PhaseAwareNet_SRC/MatConvNet/matconvnet-1.0-beta20/examples/cnn_test_dag.m

342 lines
10 KiB
Mathematica
Raw Normal View History

2021-12-30 11:48:37 +00:00
function [state,stats] = cnn_test_dag( imdb, getBatch, varargin)
%CNN_TEST_DAG Demonstrates test a CNN using the DagNN wrapper
% CNN_TEST_DAG() is a slim version to CNN_TRAIN_DAG(), just do the
% testing of the final net in the export
opts.expDir = fullfile('data','exp') ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.test = [];
opts.gpus = [] ;
opts.prefetch = false ;
opts.testEpoch = inf;
opts.testSelect = [1, 1, 1]; % (1) training; (2)validation; (3), testing
opts.saveResult = true;
opts.bnRefine = false;
opts.randomSeed = 0 ;
opts.stegoShuffle = false;
opts.cudnn = true ;
opts.extractStatsFn = @extractStats ;
opts = vl_argparse(opts, varargin) ;
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isempty(opts.test), opts.test = find(imdb.images.set==3); end
if isnan(opts.train), opts.train = [] ; end
% -------------------------------------------------------------------------
% Initialization
% -------------------------------------------------------------------------
state.getBatch = getBatch ;
% -------------------------------------------------------------------------
% Train and validate
% -------------------------------------------------------------------------
if ( opts.bnRefine )
modelPath = @(ep) fullfile(opts.expDir, sprintf('bn-epoch-%d.mat', ep));
resultPath = @(ep) fullfile(opts.expDir, sprintf('test-bn-epoch-%d.mat', ep));
else
modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
resultPath = @(ep) fullfile(opts.expDir, sprintf('test-net-epoch-%d.mat', ep));
end
start = findLastCheckpoint(opts.expDir) ;
if( start < 1 )
error( 'Found no net' );
end
if start >= 1
start = min(start, opts.testEpoch);
fprintf('%s: testing by loading epoch name %s\n', mfilename, modelPath(start) );
net = loadState(modelPath(start)) ;
end
% Make sure that we use the estimated BN moments
for i = 1:numel(net.layers)
if ( isa( net.layers(i).block, 'dagnn.BatchNorm') )
net.layers(i).block.computeMoment = false;
end
end
for epoch = start
% Set the random seed based on the epoch and opts.randomSeed.
% This is important for reproducibility, including when training
% is restarted from a checkpoint.
rng(epoch + opts.randomSeed) ;
prepareGPUs(opts, true ) ;
% Train for one epoch.
state.epoch = epoch ;
% shuffle
if( opts.stegoShuffle )
N = numel(opts.train); % TRN
M = numel(opts.val); % VAL
K = numel(opts.test); % TST
Lab = max( 1, numel(opts.gpus));
% M and N must be even, and multiple Lab
assert( ( rem( N, 2*Lab ) == 0 ) & ...
( rem( M, 2*Lab ) == 0 ) & ...
( rem( K, 2*Lab ) == 0 ) );
seq = opts.train( 2*randperm(N/2) - 1 );
seq = reshape( seq, Lab, N/(2*Lab) );
state.train = reshape( [seq; seq+1], 1, N );
seq = opts.val( 2*randperm(M/2) - 1 );
seq = reshape( seq, Lab, M/(2*Lab) );
state.val = reshape( [seq; seq+1], 1, M );
seq = opts.test( 2*randperm(K/2) - 1 );
seq = reshape( seq, Lab, K/(2*Lab) );
state.test = reshape( [seq; seq+1], 1, K );
else
state.train = opts.train(randperm(numel(opts.train))) ;
state.val = opts.val(randperm(numel(opts.val))) ;
state.test = opts.test(randperm(numel(opts.test))) ;
% N = numel(opts.train); % TRN
% M = numel(opts.val); % VAL
% K = numel(opts.test); % TST
%
%
% state.train = opts.train([1:2:N, 2:2:N]);
% state.val = opts.val([1:2:M, 2:2:M]);
% state.test = opts.test([1:2:K, 2:2:K]);
end
state.imdb = imdb ;
if numel(opts.gpus) <= 1
if( opts.testSelect(1) )
stats.train = process_epoch(net, state, opts, 'train') ;
end
if( opts.testSelect(2) )
stats.val = process_epoch(net, state, opts, 'val') ;
end
if( opts.testSelect(3) )
stats.test = process_epoch(net, state, opts, 'test');
end
else
savedNet = net.saveobj() ;
spmd
net_ = dagnn.DagNN.loadobj(savedNet) ;
if( opts.testSelect(1) )
stats_.train = process_epoch(net_, state, opts, 'train') ;
end
if( opts.testSelect(2) )
stats_.val = process_epoch(net_, state, opts, 'val') ;
end
if( opts.testSelect(3) )
stats_.test = process_epoch(net_, state, opts, 'test');
end
if labindex == 1, savedNet_ = net_.saveobj() ; end
end
net = dagnn.DagNN.loadobj(savedNet_{1}) ;
stats__ = accumulateStats(stats_) ;
if( opts.testSelect(1) )
stats.train = stats__.train ;
end
if( opts.testSelect(2) )
stats.val = stats__.val ;
end
if( opts.testSelect(3) )
stats.test = stats__.test;
end
clear net_ stats_ stats__ savedNet savedNet_ ;
end
% save
if( opts.saveResult == true )
saveState(resultPath(epoch), net, stats, state) ;
end
end
% -------------------------------------------------------------------------
function stats = process_epoch(net, state, opts, mode)
% -------------------------------------------------------------------------
% move CNN to GPU as needed
numGpus = numel(opts.gpus) ;
if numGpus >= 1
net.move('gpu') ;
end
subset = state.(mode) ;
num = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;
start = tic ;
for t=1:opts.batchSize:numel(subset)
fprintf('%s: epoch %02d: %3d/%3d:', mode, state.epoch, ...
fix((t-1)/opts.batchSize)+1, ceil(numel(subset)/opts.batchSize)) ;
batchSize = min(opts.batchSize, numel(subset) - t + 1) ;
for s=1:opts.numSubBatches
% get this image batch and prefetch the next
batchStart = t + (labindex-1) + (s-1) * numlabs ;
batchEnd = min(t+opts.batchSize-1, numel(subset)) ;
batch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
num = num + numel(batch) ;
if numel(batch) == 0, continue ; end
inputs = state.getBatch(state.imdb, batch) ;
if opts.prefetch
if s == opts.numSubBatches
batchStart = t + (labindex-1) + opts.batchSize ;
batchEnd = min(t+2*opts.batchSize-1, numel(subset)) ;
else
batchStart = batchStart + numlabs ;
end
nextBatch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
state.getBatch(state.imdb, nextBatch) ;
end
net.mode = 'test' ;
net.eval(inputs) ;
end
% get statistics
time = toc(start) + adjustTime ;
batchTime = time - stats.time ;
stats = opts.extractStatsFn(net) ;
stats.num = num ;
stats.time = time ;
currentSpeed = batchSize / batchTime ;
averageSpeed = (t + batchSize - 1) / time ;
if t == opts.batchSize + 1
% compensate for the first iteration, which is an outlier
adjustTime = 2*batchTime - time ;
stats.time = time + adjustTime ;
end
fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
for f = setdiff(fieldnames(stats)', {'num', 'time'})
f = char(f) ;
fprintf(' %s:', f) ;
fprintf(' %.3f', stats.(f)) ;
end
fprintf('\n') ;
end
net.reset() ;
net.move('cpu') ;
% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
for s = {'train', 'val', 'test'}
s = char(s) ;
total = 0 ;
% initialize stats stucture with same fields and same order as
% stats_{1}
stats__ = stats_{1} ;
if ( ~isfield(stats__, s) )
continue;
end
names = fieldnames(stats__.(s))' ;
values = zeros(1, numel(names)) ;
fields = cat(1, names, num2cell(values)) ;
stats.(s) = struct(fields{:}) ;
for g = 1:numel(stats_)
stats__ = stats_{g} ;
num__ = stats__.(s).num ;
total = total + num__ ;
for f = setdiff(fieldnames(stats__.(s))', 'num')
f = char(f) ;
stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;
if g == numel(stats_)
stats.(s).(f) = stats.(s).(f) / total ;
end
end
end
stats.(s).num = total ;
end
% -------------------------------------------------------------------------
function stats = extractStats(net)
% -------------------------------------------------------------------------
sel = find(cellfun(@(x) isa(x,'dagnn.Loss'), {net.layers.block})) ;
stats = struct() ;
for i = 1:numel(sel)
stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
end
% -------------------------------------------------------------------------
function saveState(fileName, net, stats, state )
% -------------------------------------------------------------------------
net_ = net ;
net = net_.saveobj() ;
save(fileName, 'net', 'stats', 'state') ;
% -------------------------------------------------------------------------
function [net, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;
% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+).mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;
% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
if numGpus > 1
% check parallel pool integrity as it could have timed out
pool = gcp('nocreate') ;
if ~isempty(pool) && pool.NumWorkers ~= numGpus
delete(pool) ;
end
pool = gcp('nocreate') ;
if isempty(pool)
parpool('local', numGpus) ;
cold = true ;
end
end
if numGpus >= 1 && cold
fprintf('%s: resetting GPU\n', mfilename)
if numGpus == 1
gpuDevice(opts.gpus)
else
spmd, gpuDevice(opts.gpus(labindex)), end
end
end
%end