# Steganalysis/YeNet-Tensorflow/utils.py

import tensorflow as tf
import numpy as np
import sys
import time
from glob import glob
from functools import partial
import os
from os.path import expanduser

home = expanduser("~")
user = home.split('/')[-1]
sys.path.append(home + '/tflib/')
# queues provides GeneratorRunner and queueSelection; generator provides the
# gen_* input generators used below.
from queues import *
from generator import *

def optimistic_restore(session, save_file, graph=None):
    """Restore only the variables whose name and shape match the checkpoint."""
    # Resolve the default graph at call time; a default argument of
    # tf.get_default_graph() would be evaluated once, at import time.
    if graph is None:
        graph = tf.get_default_graph()
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    var_names = sorted([(var.name, var.name.split(':')[0])
                        for var in tf.global_variables()
                        if var.name.split(':')[0] in saved_shapes])
    restore_vars = []
    for var_name, saved_var_name in var_names:
        curr_var = graph.get_tensor_by_name(var_name)
        var_shape = curr_var.get_shape().as_list()
        if var_shape == saved_shapes[saved_var_name]:
            restore_vars.append(curr_var)
    opt_saver = tf.train.Saver(restore_vars)
    opt_saver.restore(session, save_file)
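
# Example (a sketch, not part of the original file; the checkpoint path is
# hypothetical): partially restore a checkpoint, silently skipping variables
# whose name or shape no longer matches the current graph.
#
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       optimistic_restore(sess, 'logs/yenet/Model_10000.ckpt')
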

class average_summary(object):
    """Accumulate a scalar over num_iterations steps and log its mean."""
    def __init__(self, variable, name, num_iterations):
        self.sum_variable = tf.get_variable(name, shape=[],
                                            initializer=tf.constant_initializer(0.),
                                            dtype='float32',
                                            trainable=False,
                                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
        with tf.control_dependencies([variable]):
            self.increment_op = tf.assign_add(self.sum_variable, variable)
        self.mean_variable = self.sum_variable / float(num_iterations)
        self.summary = tf.summary.scalar(name, self.mean_variable)
        # Writing the summary resets the accumulator for the next window.
        with tf.control_dependencies([self.summary]):
            self.reset_variable_op = tf.assign(self.sum_variable, 0.)

    def add_summary(self, sess, writer, step):
        s, _ = sess.run([self.summary, self.reset_variable_op])
        writer.add_summary(s, step)
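
# Example (a sketch; 'loss' is assumed to be a scalar tensor already in the
# graph, and 'sess'/'writer' to be a live session and FileWriter): average a
# loss over 100 iterations, then emit a single summary point.
#
#   loss_s = average_summary(loss, 'train_loss', 100)
#   for i in range(100):
#       sess.run(loss_s.increment_op)
#   loss_s.add_summary(sess, writer, step=100)
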

class Model(object):
    def __init__(self, is_training=None, data_format='NCHW'):
        self.data_format = data_format
        if is_training is None:
            # shape=[] is required when passing an Initializer object.
            self.is_training = tf.get_variable('is_training', dtype=tf.bool,
                                               shape=[],
                                               initializer=tf.constant_initializer(True),
                                               trainable=False)
        else:
            self.is_training = is_training

    def _build_model(self, inputs):
        raise NotImplementedError('Here is your model definition')

    def _build_losses(self, labels):
        self.labels = tf.cast(labels, tf.int64)
        with tf.variable_scope('loss'):
            oh = tf.one_hot(self.labels, 2)
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    labels=oh, logits=self.outputs))
        with tf.variable_scope('accuracy'):
            am = tf.argmax(self.outputs, 1)
            equal = tf.equal(am, self.labels)
            self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))
        return self.loss, self.accuracy
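
# Example (a sketch, not from the original file): a minimal subclass showing
# the contract _build_model must fulfil -- it has to set self.outputs to a
# [batch, 2] logits tensor before _build_losses is called. The layers are
# hypothetical placeholders, not the actual YeNet architecture.
#
#   class ToyNet(Model):
#       def _build_model(self, inputs):
#           flat = tf.layers.flatten(inputs)
#           self.outputs = tf.layers.dense(flat, 2)
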

def train(model_class, train_gen, valid_gen, train_batch_size,
          valid_batch_size, valid_ds_size, optimizer,
          train_interval, valid_interval, max_iter,
          save_interval, log_path, num_runner_threads=1,
          load_path=None):
    """Train model_class, validating every valid_interval iterations and
    saving a checkpoint every save_interval iterations."""
    tf.reset_default_graph()
    train_runner = GeneratorRunner(train_gen, train_batch_size * 10)
    valid_runner = GeneratorRunner(valid_gen, valid_batch_size * 10)
    is_training = tf.get_variable('is_training', dtype=tf.bool,
                                  initializer=True, trainable=False)
    if train_batch_size == valid_batch_size:
        batch_size = train_batch_size
        disable_training_op = tf.assign(is_training, False)
        enable_training_op = tf.assign(is_training, True)
    else:
        # With different batch sizes, batch_size must live in the graph so it
        # can be switched together with is_training.
        batch_size = tf.get_variable('batch_size', dtype=tf.int32,
                                     initializer=train_batch_size,
                                     trainable=False,
                                     collections=[tf.GraphKeys.LOCAL_VARIABLES])
        disable_training_op = tf.group(tf.assign(is_training, False),
                                       tf.assign(batch_size, valid_batch_size))
        enable_training_op = tf.group(tf.assign(is_training, True),
                                      tf.assign(batch_size, train_batch_size))
    img_batch, label_batch = queueSelection([valid_runner, train_runner],
                                            tf.cast(is_training, tf.int32),
                                            batch_size)
    model = model_class(is_training, 'NCHW')
    model._build_model(img_batch)
    loss, accuracy = model._build_losses(label_batch)
    regularization_losses = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES)
    regularized_loss = tf.add_n([loss] + regularization_losses)
    train_loss_s = average_summary(loss, 'train_loss', train_interval)
    train_accuracy_s = average_summary(accuracy, 'train_accuracy',
                                       train_interval)
    valid_loss_s = average_summary(loss, 'valid_loss',
                                   float(valid_ds_size) / float(valid_batch_size))
    valid_accuracy_s = average_summary(accuracy, 'valid_accuracy',
                                       float(valid_ds_size) / float(valid_batch_size))
    global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    minimize_op = optimizer.minimize(regularized_loss, global_step)
    train_op = tf.group(minimize_op, train_loss_s.increment_op,
                        train_accuracy_s.increment_op)
    increment_valid = tf.group(valid_loss_s.increment_op,
                               valid_accuracy_s.increment_op)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)
    with tf.Session() as sess:
        sess.run(init_op)
        if load_path is not None:
            loader = tf.train.Saver(reshape=True)
            loader.restore(sess, load_path)
        train_runner.start_threads(sess, num_runner_threads)
        valid_runner.start_threads(sess, 1)
        writer = tf.summary.FileWriter(log_path + '/LogFile/',
                                       sess.graph)
        start = sess.run(global_step)
        # One full validation pass before training starts.
        sess.run(disable_training_op)
        sess.run([valid_loss_s.reset_variable_op,
                  valid_accuracy_s.reset_variable_op,
                  train_loss_s.reset_variable_op,
                  train_accuracy_s.reset_variable_op])
        _time = time.time()
        for j in range(0, valid_ds_size, valid_batch_size):
            sess.run([increment_valid])
        _acc_val = sess.run(valid_accuracy_s.mean_variable)
        print "validation:", _acc_val, "|", \
            "duration:", time.time() - _time, "seconds"
        valid_accuracy_s.add_summary(sess, writer, start)
        valid_loss_s.add_summary(sess, writer, start)
        sess.run(enable_training_op)
        for i in xrange(start + 1, max_iter + 1):
            sess.run(train_op)
            if i % train_interval == 0:
                train_loss_s.add_summary(sess, writer, i)
                train_accuracy_s.add_summary(sess, writer, i)
            if i % valid_interval == 0:
                sess.run(disable_training_op)
                for j in range(0, valid_ds_size, valid_batch_size):
                    sess.run([increment_valid])
                valid_loss_s.add_summary(sess, writer, i)
                valid_accuracy_s.add_summary(sess, writer, i)
                sess.run(enable_training_op)
            if i % save_interval == 0:
                saver.save(sess, log_path + '/Model_' + str(i) + '.ckpt')
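
# Example call (a sketch, not from the original file; the model class, paths,
# sizes and hyper-parameters below are all hypothetical):
#
#   optimizer = tf.train.AdadeltaOptimizer(0.4)
#   train(YeNet, train_gen, valid_gen,
#         train_batch_size=32, valid_batch_size=32, valid_ds_size=5000,
#         optimizer=optimizer, train_interval=100, valid_interval=5000,
#         max_iter=500000, save_interval=5000, log_path='logs/yenet')
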

def test_dataset(model_class, gen, batch_size, ds_size, load_path):
    """Evaluate a saved checkpoint on one dataset and print loss/accuracy."""
    tf.reset_default_graph()
    runner = GeneratorRunner(gen, batch_size * 10)
    img_batch, label_batch = runner.get_batched_inputs(batch_size)
    model = model_class(False, 'NCHW')
    model._build_model(img_batch)
    loss, accuracy = model._build_losses(label_batch)
    loss_summary = average_summary(loss, 'loss',
                                   float(ds_size) / float(batch_size))
    accuracy_summary = average_summary(accuracy, 'accuracy',
                                       float(ds_size) / float(batch_size))
    increment_op = tf.group(loss_summary.increment_op,
                            accuracy_summary.increment_op)
    global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, load_path)
        runner.start_threads(sess, 1)
        for j in range(0, ds_size, batch_size):
            sess.run(increment_op)
        mean_loss, mean_accuracy = sess.run([loss_summary.mean_variable,
                                             accuracy_summary.mean_variable])
        print "Accuracy:", mean_accuracy, "| Loss:", mean_loss

def find_best(model_class, valid_gen, test_gen, valid_batch_size,
              test_batch_size, valid_ds_size, test_ds_size, load_paths):
    """Pick the checkpoint with the best validation accuracy, then test it."""
    tf.reset_default_graph()
    valid_runner = GeneratorRunner(valid_gen, valid_batch_size * 30)
    img_batch, label_batch = valid_runner.get_batched_inputs(valid_batch_size)
    model = model_class(False, 'NCHW')
    model._build_model(img_batch)
    loss, accuracy = model._build_losses(label_batch)
    loss_summary = average_summary(loss, 'loss',
                                   float(valid_ds_size)
                                   / float(valid_batch_size))
    accuracy_summary = average_summary(accuracy, 'accuracy',
                                       float(valid_ds_size)
                                       / float(valid_batch_size))
    increment_op = tf.group(loss_summary.increment_op,
                            accuracy_summary.increment_op)
    global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)
    accuracy_arr = []
    loss_arr = []
    print "validation:"
    for load_path in load_paths:
        # A fresh session per checkpoint; init_op re-zeroes the accumulators.
        with tf.Session() as sess:
            sess.run(init_op)
            saver.restore(sess, load_path)
            valid_runner.start_threads(sess, 1)
            _time = time.time()
            for j in range(0, valid_ds_size, valid_batch_size):
                sess.run(increment_op)
            mean_loss, mean_accuracy = sess.run([loss_summary.mean_variable,
                                                 accuracy_summary.mean_variable])
            accuracy_arr.append(mean_accuracy)
            loss_arr.append(mean_loss)
            print load_path
            print "Accuracy:", accuracy_arr[-1], "| Loss:", loss_arr[-1], \
                "in", time.time() - _time, "seconds."
    argmax = np.argmax(accuracy_arr)
    print "best savestate:", load_paths[argmax], "with", \
        accuracy_arr[argmax], "accuracy and", loss_arr[argmax], \
        "loss on validation"
    print "test:"
    test_dataset(model_class, test_gen, test_batch_size, test_ds_size,
                 load_paths[argmax])
    return argmax, accuracy_arr, loss_arr
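
# Example (a sketch; the checkpoint glob pattern and sizes are hypothetical --
# checkpoints saved by train() above consist of several files, so the paths
# are recovered from the .meta files):
#
#   load_paths = sorted(glob('logs/yenet/Model_*.ckpt.meta'))
#   load_paths = [p[:-len('.meta')] for p in load_paths]
#   find_best(YeNet, valid_gen, test_gen, 32, 32, 5000, 5000, load_paths)
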

def extract_stats_outputs(model_class, gen, batch_size, ds_size, load_path):
    """Run the model over a dataset and collect its stats_outputs features.

    model_class must expose a [batch, feature_dim] tensor as
    model.stats_outputs after _build_model; ds_size is assumed to be a
    multiple of batch_size, otherwise the last slice assignment fails."""
    tf.reset_default_graph()
    runner = GeneratorRunner(gen, batch_size * 10)
    img_batch, label_batch = runner.get_batched_inputs(batch_size)
    model = model_class(False, 'NCHW')
    model._build_model(img_batch)
    global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=10000)
    stats_outputs_arr = np.empty([ds_size,
                                  model.stats_outputs.get_shape().as_list()[1]])
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, load_path)
        runner.start_threads(sess, 1)
        for j in range(0, ds_size, batch_size):
            stats_outputs_arr[j:j+batch_size] = sess.run(model.stats_outputs)
    return stats_outputs_arr
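
# Example (a sketch; model class, path and sizes are hypothetical): extract
# features for an external classifier. ds_size counts cover and stego images
# separately, hence the factor 2.
#
#   feats = extract_stats_outputs(YeNet, valid_gen, 16, 5000 * 2,
#                                 'logs/yenet/Model_500000.ckpt')
#   # feats has shape [ds_size, feature_dim]
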

def stats_outputs_all_datasets(model_class, ds_head_dir, payload,
                               algorithm, load_path, save_dir):
    """Extract and save stats_outputs features for the train/valid/test splits."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir + '/')
    # Drop the decimal point from the payload, e.g. 0.4 -> '04'.
    payload_str = ''.join(str(payload).split('.'))
    train_ds_size = len(glob(ds_head_dir + '/train/cover/*'))
    valid_ds_size = len(glob(ds_head_dir + '/valid/cover/*'))
    test_ds_size = len(glob(ds_head_dir + '/test/cover/*'))
    train_gen = partial(gen_all_flip_and_rot, ds_head_dir +
                        '/train/cover/', ds_head_dir + '/train/' +
                        algorithm + '/payload' + payload_str + '/stego/')
    valid_gen = partial(gen_valid, ds_head_dir + '/valid/cover/',
                        ds_head_dir + '/valid/' + algorithm +
                        '/payload' + payload_str + '/stego/')
    test_gen = partial(gen_valid, ds_head_dir + '/test/cover/',
                       ds_head_dir + '/test/' + algorithm +
                       '/payload' + payload_str + '/stego/')
    print "train..."
    # gen_all_flip_and_rot yields each cover/stego pair under 4 rotations and
    # 2 flips, hence the factor 2 * 4 * 2 on the training set size.
    stats_outputs = extract_stats_outputs(model_class, train_gen, 16,
                                          train_ds_size * 2 * 4 * 2,
                                          load_path)
    stats_shape = stats_outputs.shape
    stats_outputs = stats_outputs.reshape(train_ds_size, 2, 4,
                                          2, stats_shape[-1])
    stats_outputs = np.transpose(stats_outputs, axes=[0, 3, 2, 1, 4])
    np.save(save_dir + '/train.npy', stats_outputs)
    print "validation..."
    stats_outputs = extract_stats_outputs(model_class, valid_gen, 16,
                                          valid_ds_size * 2, load_path)
    np.save(save_dir + '/valid.npy', stats_outputs)
    print "test..."
    stats_outputs = extract_stats_outputs(model_class, test_gen, 16,
                                          test_ds_size * 2, load_path)
    np.save(save_dir + '/test.npy', stats_outputs)
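
# Example (a sketch; the directory layout, embedding algorithm and payload are
# hypothetical -- stego images are expected under
# <ds_head_dir>/<split>/<algorithm>/payload<payload_str>/stego/):
#
#   stats_outputs_all_datasets(YeNet, home + '/BOSSbase', 0.4,
#                              'wow', 'logs/yenet/Model_500000.ckpt',
#                              home + '/features/wow_04')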