179 lines
6.9 KiB
Plaintext
179 lines
6.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import sys\n",
|
|
"os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
|
|
"os.environ['CUDA_VISIBLE_DEVICES'] = '1' # set a GPU (with GPU Number)\n",
|
|
"home = os.path.expanduser(\"~\")\n",
|
|
"sys.path.append(home + '/tflib/') # path for 'tflib' folder\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from scipy.io import loadmat\n",
|
|
"from SCA_SRNet_Spatial import * # use 'SCA_SRNet_JPEG' for JPEG domain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def trnGen(cover_path, stego_path, cover_beta_path, stego_beta_path, thread_idx=0, n_threads=1):\n",
|
|
" IL=os.listdir(cover_path)\n",
|
|
" img_shape = plt.imread(cover_path +IL[0]).shape\n",
|
|
" batch = np.empty((2, img_shape[0], img_shape[1], 2), dtype='float32')\n",
|
|
" while True:\n",
|
|
" indx = np.random.permutation(len(IL))\n",
|
|
" for i in indx:\n",
|
|
" batch[0,:,:,0] = plt.imread(cover_path + IL[i]) # use loadmat for loading JPEG decompressed images \n",
|
|
" batch[0,:,:,1] = loadmat(cover_beta_path + IL[i].replace('pgm','mat'))['Beta'] # adjust for JPEG images\n",
|
|
" batch[1,:,:,0] = plt.imread(stego_path + IL[i]) # use loadmat for loading JPEG decompressed images \n",
|
|
" batch[1,:,:,1] = loadmat(stego_beta_path + IL[i].replace('pgm','mat'))['Beta'] # adjust for JPEG images\n",
|
|
" rot = random.randint(0,3)\n",
|
|
" if rand() < 0.5:\n",
|
|
" yield [np.rot90(batch, rot, axes=[1,2]), np.array([0,1], dtype='uint8')]\n",
|
|
" else:\n",
|
|
" yield [np.flip(np.rot90(batch, rot, axes=[1,2]), axis=2), np.array([0,1], dtype='uint8')] "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def valGen(cover_path, stego_path, cover_beta_path, stego_beta_path, thread_idx=0, n_threads=1):\n",
|
|
" IL=os.listdir(cover_path)\n",
|
|
" img_shape = plt.imread(cover_path +IL[0]).shape\n",
|
|
" batch = np.empty((2, img_shape[0], img_shape[1], 2), dtype='float32')\n",
|
|
" while True:\n",
|
|
" for i in range(len(IL)):\n",
|
|
" batch[0,:,:,0] = plt.imread(cover_path + IL[i]) # use loadmat for loading JPEG decompressed images \n",
|
|
" batch[0,:,:,1] = loadmat(cover_beta_path + IL[i].replace('pgm','mat'))['Beta'] # adjust for JPEG images\n",
|
|
" batch[1,:,:,0] = plt.imread(stego_path + IL[i]) # use loadmat for loading JPEG decompressed images \n",
|
|
" batch[1,:,:,1] = loadmat(stego_beta_path + IL[i].replace('pgm','mat'))['Beta'] # adjust for JPEG images\n",
|
|
" yield [batch, np.array([0,1], dtype='uint8') ]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_batch_size = 32\n",
|
|
"valid_batch_size = 40\n",
|
|
"max_iter = 500000\n",
|
|
"train_interval=100\n",
|
|
"valid_interval=5000\n",
|
|
"save_interval=5000\n",
|
|
"num_runner_threads=10\n",
|
|
"\n",
|
|
"# Save Betas as '.mat' files with variable name \"Beta\" and put them in their corresponding directories. Make sure \n",
|
|
"# all mat files in the directories can be loaded in Python without any errors.\n",
|
|
"\n",
|
|
"TRAIN_COVER_DIR = '/media/Cover_TRN/'\n",
|
|
"TRAIN_STEGO_DIR = '/media/Stego_WOW_0.5_TRN/'\n",
|
|
"TRAIN_COVER_BETA_DIR = '/media/Beta_Cover_WOW_0.5_TRN/'\n",
|
|
"TRAIN_STEGO_BETA_DIR = '/media/Beta_Stego_WOW_0.5_TRN/'\n",
|
|
"\n",
|
|
"VALID_COVER_DIR = '/media/Cover_VAL/'\n",
|
|
"VALID_STEGO_DIR = '/media/Stego_WOW_0.5_VAL/'\n",
|
|
"VALID_COVER_BETA_DIR = '/media/Beta_Cover_WOW_0.5_VAL/'\n",
|
|
"VALID_STEGO_BETA_DIR = '/media/Beta_Stego_WOW_0.5_VAL/'\n",
|
|
"\n",
|
|
"train_gen = partial(trnGen, \\\n",
|
|
" TRAIN_COVER_DIR, TRAIN_STEGO_DIR, TRAIN_COVER_BETA_DIR, TRAIN_STEGO_BETA_DIR) \n",
|
|
"valid_gen = partial(valGen, \\\n",
|
|
" VALID_COVER_DIR, VALID_STEGO_DIR, VALID_COVER_BETA_DIR, VALID_STEGO_BETA_DIR)\n",
|
|
"\n",
|
|
"LOG_DIR= '/media/LogFiles/SCA_WOW_0.5' # path for a log directory\n",
|
|
"# load_path= os.path.join(LOG_DIR, 'Model_460000.ckpt') # continue training from a specific checkpoint (LOG_DIR has no trailing slash)\n",
|
|
"load_path=None # training from scratch\n",
|
|
"\n",
|
|
"if not os.path.exists(LOG_DIR):\n",
|
|
" os.makedirs(LOG_DIR)\n",
|
|
" \n",
|
|
"train_ds_size = len(glob(TRAIN_COVER_DIR + '/*')) * 2\n",
|
|
"valid_ds_size = len(glob(VALID_COVER_DIR +'/*')) * 2\n",
|
|
"print 'train_ds_size: %i'%train_ds_size\n",
|
|
"print 'valid_ds_size: %i'%valid_ds_size\n",
|
|
"\n",
|
|
"if valid_ds_size % valid_batch_size != 0:\n",
|
|
" raise ValueError(\"change batch size for validation\")\n",
|
|
"\n",
|
|
"optimizer = AdamaxOptimizer\n",
|
|
"boundaries = [400000] # learning rate adjustment at iteration 400K\n",
|
|
"values = [0.001, 0.0001] # learning rates\n",
|
|
"train(SCA_SRNet, train_gen, valid_gen , train_batch_size, valid_batch_size, valid_ds_size, \\\n",
|
|
" optimizer, boundaries, values, train_interval, valid_interval, max_iter,\\\n",
|
|
" save_interval, LOG_DIR,num_runner_threads, load_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Testing \n",
|
|
"TEST_COVER_DIR = '/media/Cover_TST/'\n",
|
|
"TEST_STEGO_DIR = '/media/Stego_WOW_0.5_TST/'\n",
|
|
"TEST_COVER_BETA_DIR = '/media/Beta_Cover_WOW_0.5_TST/'\n",
|
|
"TEST_STEGO_BETA_DIR = '/media/Beta_Stego_WOW_0.5_TST/'\n",
|
|
"\n",
|
|
"test_batch_size=40\n",
|
|
"LOG_DIR = '/media/LogFiles/SCA_WOW_0.5/' \n",
|
|
"LOAD_DIR = LOG_DIR + 'Model_435000.ckpt' # loading from a specific checkpoint\n",
|
|
"\n",
|
|
"test_gen = partial(gen_valid, \\\n",
|
|
" TEST_COVER_DIR, TEST_STEGO_DIR)\n",
|
|
"\n",
|
|
"test_ds_size = len(glob(TEST_COVER_DIR + '/*')) * 2\n",
|
|
"print 'test_ds_size: %i'%test_ds_size\n",
|
|
"\n",
|
|
"if test_ds_size % test_batch_size != 0:\n",
|
|
" raise ValueError(\"change batch size for testing!\")\n",
|
|
"\n",
|
|
"test_dataset(SCA_SRNet, test_gen, test_batch_size, test_ds_size, LOAD_DIR)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 2",
|
|
"language": "python",
|
|
"name": "python2"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.15rc1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|