{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from functools import partial\n",
    "from glob import glob\n",
    "\n",
    "# GPU selection must happen before SRNet (and anything it imports) is loaded.\n",
    "# PCI_BUS_ID ordering makes the device index below follow the PCI bus order.\n",
    "os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1' # set a GPU (with GPU Number)\n",
    "home = os.path.expanduser(\"~\")\n",
    "sys.path.append(home + '/tflib/') # path for 'tflib' folder\n",
    "# The star import is expected to provide SRNet, train, test_dataset, gen_flip_and_rot,\n",
    "# gen_valid and AdamaxOptimizer, all of which are used in the cells below.\n",
    "from SRNet import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training configuration\n",
    "train_batch_size = 32\n",
    "valid_batch_size = 40\n",
    "max_iter = 500000\n",
    "train_interval = 100\n",
    "valid_interval = 5000\n",
    "save_interval = 5000\n",
    "num_runner_threads = 10\n",
    "\n",
    "# Cover and Stego directories for training and validation. For the spatial domain, put cover and stego images in their\n",
    "# corresponding directories. For the JPEG domain, decompress images to the spatial domain without rounding to integers and\n",
    "# save them as '.mat' files with variable name \"im\". Put the '.mat' files in their corresponding directories. Make sure\n",
    "# all mat files in the directories can be loaded in Python without any errors.\n",
    "\n",
    "TRAIN_COVER_DIR = '/media/TRN/Cover/'\n",
    "TRAIN_STEGO_DIR = '/media/TRN/JUNI_75_04/'\n",
    "\n",
    "VALID_COVER_DIR = '/media/VAL/Cover/'\n",
    "VALID_STEGO_DIR = '/media/VAL/JUNI_75_04/'\n",
    "\n",
    "# Bind the directory pair to each generator; the remaining arguments are supplied by the caller.\n",
    "train_gen = partial(gen_flip_and_rot, TRAIN_COVER_DIR, TRAIN_STEGO_DIR)\n",
    "valid_gen = partial(gen_valid, VALID_COVER_DIR, VALID_STEGO_DIR)\n",
    "\n",
    "LOG_DIR = '/media/LogFiles/JUNI_75_04' # path for a log directory\n",
    "# os.path.join keeps the checkpoint path valid whether or not LOG_DIR ends with a slash.\n",
    "# load_path = os.path.join(LOG_DIR, 'Model_460000.ckpt') # continue training from a specific checkpoint\n",
    "load_path = None # training from scratch\n",
    "\n",
    "if not os.path.exists(LOG_DIR):\n",
    "    os.makedirs(LOG_DIR)\n",
    "\n",
    "# Dataset size is twice the number of cover files (assumes one stego file per cover file).\n",
    "train_ds_size = len(glob(os.path.join(TRAIN_COVER_DIR, '*'))) * 2\n",
    "valid_ds_size = len(glob(os.path.join(VALID_COVER_DIR, '*'))) * 2\n",
    "print('train_ds_size: %i' % train_ds_size)\n",
    "print('valid_ds_size: %i' % valid_ds_size)\n",
    "\n",
    "# The validation set must split into whole batches.\n",
    "if valid_ds_size % valid_batch_size != 0:\n",
    "    raise ValueError(\"change batch size for validation\")\n",
    "\n",
    "optimizer = AdamaxOptimizer\n",
    "boundaries = [400000] # learning rate adjustment at iteration 400K\n",
    "values = [0.001, 0.0001] # learning rates\n",
    "\n",
    "train(SRNet, train_gen, valid_gen, train_batch_size, valid_batch_size, valid_ds_size,\n",
    "      optimizer, boundaries, values, train_interval, valid_interval, max_iter,\n",
    "      save_interval, LOG_DIR, num_runner_threads, load_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Testing\n",
    "# Cover and Stego directories for testing\n",
    "TEST_COVER_DIR = '/media/TST/Cover/'\n",
    "TEST_STEGO_DIR = '/media/TST/JUNI_75_04/'\n",
    "\n",
    "test_batch_size = 40\n",
    "LOG_DIR = '/media/LogFiles/JUNI_75_04'\n",
    "# os.path.join inserts the separator, so LOG_DIR does not need a trailing slash.\n",
    "LOAD_CKPT = os.path.join(LOG_DIR, 'Model_435000.ckpt') # loading from a specific checkpoint\n",
    "\n",
    "test_gen = partial(gen_valid, TEST_COVER_DIR, TEST_STEGO_DIR)\n",
    "\n",
    "# Dataset size is twice the number of cover files (assumes one stego file per cover file).\n",
    "test_ds_size = len(glob(os.path.join(TEST_COVER_DIR, '*'))) * 2\n",
    "print('test_ds_size: %i' % test_ds_size)\n",
    "\n",
    "# The test set must split into whole batches.\n",
    "if test_ds_size % test_batch_size != 0:\n",
    "    raise ValueError(\"change batch size for testing!\")\n",
    "\n",
    "test_dataset(SRNet, test_gen, test_batch_size, test_ds_size, LOAD_CKPT)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}