Code for paper "SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis"
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

191 lines
7.5 KiB

import argparse
import importlib
import os
import sys
import shutil
import json
import tensorflow as tf
from time import gmtime, strftime
# Folder holding the training sources. launch_training() copies its .py files
# into each new run's log directory and appends it to sys.path so that the
# `config` module and the entry-point module can be imported from it.
src_dir = './src_single'
def launch_training(**kwargs):
    """Set up a new or resumed run and hand control to the training entry point.

    If ``kwargs["resume_from"]`` is ``None`` or empty, a new run is created:
    timestamped ``./log_skgan_<time>`` and ``./ckpt_skgan_<time>`` directories
    are made, every ``.py`` file from ``src_dir`` is copied into the log
    directory, and the parameters are written to ``param_0.json``.

    Otherwise the named run is resumed: the parameters stored in the
    highest-numbered ``param_<iter>.json`` of its log directory override the
    ones passed in (except ``num_gpu`` and ``iter_from``), the latest
    checkpoint determines ``iter_from``, and a new ``param_<iter_from>.json``
    is written.

    In both cases ``Config.set_from_dict(kwargs)`` is called and the module
    named by the entry point is imported and its ``train(**kwargs)`` invoked.

    Args:
        **kwargs: Run parameters. ``resume_from`` and ``entry_point`` are read
            here; ``log_dir``, ``ckpt_dir``, ``resume_from`` and ``iter_from``
            are filled in before the parameters are forwarded.

    Returns:
        A ``(status, appendix)`` tuple, where ``status`` is the value returned
        by the entry point's ``train`` and ``appendix`` is the run's timestamp
        suffix. When the run cannot be set up (source folder missing, or a
        malformed resume name) ``status`` is ``None`` and ``appendix`` is the
        unchanged ``resume_from`` value, so callers can always unpack the
        result.

    Raises:
        RuntimeError: When resuming and the log directory holds no
            ``param_<iter>.json`` file, or no checkpoint can be found.
    """
    appendix = kwargs["resume_from"]

    if appendix is None or appendix == '':
        # Validate the source folder first so a failed launch does not leave
        # empty log/checkpoint directories behind.
        if not os.path.isdir(src_dir):
            print("src folder does not exist.")
            return None, appendix

        cur_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
        log_dir = './log_skgan_' + cur_time
        ckpt_dir = './ckpt_skgan_' + cur_time
        for directory in (log_dir, ckpt_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        # Keep a copy of the scripts used for this run next to its logs; a
        # resumed run imports its code from this copy (see sys.path below).
        for filename in os.listdir(src_dir):
            if filename.endswith(".py"):
                shutil.copy(os.path.join(src_dir, filename), log_dir)

        appendix = cur_time
        kwargs['log_dir'] = log_dir
        kwargs['ckpt_dir'] = ckpt_dir
        kwargs["resume_from"] = appendix
        kwargs["iter_from"] = 0

        # Save parameters
        with open(os.path.join(log_dir, 'param_%d.json' % 0), 'w') as fp:
            json.dump(kwargs, fp, indent=4)

        sys.path.append(src_dir)
        entry_point_module = kwargs['entry_point']
        banner = "Launching new train: %s" % cur_time
    else:
        # A run name is a "%Y-%m-%d-%H-%M-%S" timestamp: six dash-separated fields.
        if len(appendix.split('-')) != 6:
            print("Invalid resume folder")
            return None, appendix

        log_dir = './log_skgan_' + appendix
        ckpt_dir = './ckpt_skgan_' + appendix

        # Find the most recent parameter file. Only names of the form
        # param_<digits>.json count, so unrelated JSON files are ignored.
        param_iters = []
        for name in os.listdir(log_dir):
            stem, ext = os.path.splitext(name)
            if ext != '.json' or not os.path.isfile(os.path.join(log_dir, name)):
                continue
            prefix, _, suffix = stem.partition('_')
            if prefix == 'param' and suffix.isdigit():
                param_iters.append(int(suffix))
        if not param_iters:
            raise RuntimeError("No param_<iter>.json file found in %s" % log_dir)

        with open(os.path.join(log_dir, 'param_%d.json' % max(param_iters)), 'r') as fp:
            params = json.load(fp)
        entry_point_module = params['entry_point']

        # Recover parameters; these two always come from the current call.
        ignored = ('num_gpu', 'iter_from')
        for key, value in params.items():
            if key not in ignored:
                kwargs[key] = value

        sys.path.append(log_dir)

        # Get latest checkpoint filename
        ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt_file is None:
            raise RuntimeError("No checkpoint found in %s" % ckpt_dir)
        # The step number is the part after the last dash of the checkpoint
        # name; splitting from the right tolerates dashes in the name prefix.
        iter_from = int(os.path.basename(ckpt_file).rsplit('-', 1)[-1]) + 1

        kwargs['log_dir'] = log_dir
        kwargs['ckpt_dir'] = ckpt_dir
        kwargs['iter_from'] = iter_from
        kwargs['entry_point'] = entry_point_module

        # Save new set of parameters
        with open(os.path.join(log_dir, 'param_%d.json' % iter_from), 'w') as fp:
            json.dump(kwargs, fp, indent=4)

        banner = "Launching train from checkpoint: %s" % appendix

    # `config` is resolved through the sys.path entry appended above.
    from config import Config
    Config.set_from_dict(kwargs)
    print(banner)

    # Launch train
    train_module = importlib.import_module(entry_point_module)
    status = train_module.train(**kwargs)

    return status, appendix
def launch_test(**kwargs):
    """Restore the parameters of a past run and call its entry point's ``test``.

    The run is identified by ``kwargs["resume_from"]``, the timestamp suffix
    of its ``./log_skgan_<suffix>`` / ``./ckpt_skgan_<suffix>`` directories.
    Parameters stored in the highest-numbered ``param_<iter>.json`` of the log
    directory override the ones passed in (except ``num_gpu`` and
    ``iter_from``). ``Config.set_from_dict(kwargs)`` is then called and the
    recorded entry-point module is imported from the log directory and its
    ``test(**kwargs)`` invoked.

    Args:
        **kwargs: Run parameters; ``resume_from`` is required.

    Returns:
        None. If ``resume_from`` is missing or is not made of six
        dash-separated fields, a message is printed and nothing else happens.

    Raises:
        RuntimeError: When the log directory holds no ``param_<iter>.json``.
    """
    appendix = kwargs["resume_from"]
    # A run name is a yyyy-mm-dd-hr-min-sec timestamp: six dash-separated fields.
    if not appendix or len(appendix.split('-')) != 6:
        print("Invalid resume folder")
        return

    log_dir = './log_skgan_' + appendix
    ckpt_dir = './ckpt_skgan_' + appendix

    # The log directory holds the copy of the code saved with the run.
    sys.path.append(log_dir)

    kwargs['log_dir'] = log_dir
    kwargs['ckpt_dir'] = ckpt_dir

    # Find the most recent parameter file. Only names of the form
    # param_<digits>.json count, so unrelated JSON files are ignored.
    param_iters = []
    for name in os.listdir(log_dir):
        stem, ext = os.path.splitext(name)
        if ext != '.json' or not os.path.isfile(os.path.join(log_dir, name)):
            continue
        prefix, _, suffix = stem.partition('_')
        if prefix == 'param' and suffix.isdigit():
            param_iters.append(int(suffix))
    if not param_iters:
        raise RuntimeError("No param_<iter>.json file found in %s" % log_dir)

    with open(os.path.join(log_dir, 'param_%d.json' % max(param_iters)), 'r') as fp:
        params = json.load(fp)
    entry_point_module = params['entry_point']

    # Recover parameters; these two always come from the current call.
    ignored = ("num_gpu", 'iter_from')
    for key, value in params.items():
        if key not in ignored:
            kwargs[key] = value

    from config import Config
    Config.set_from_dict(kwargs)
    print("Launching test from checkpoint: %s" % appendix)

    # Launch test
    test_module = importlib.import_module(entry_point_module)
    test_module.test(**kwargs)
def main():
    """Parse the command line and launch training or testing.

    Every option except ``--mode`` is forwarded as a keyword argument to
    ``launch_training`` or ``launch_test``. In train mode, a run whose
    training status is -1 (NaN during training, per the original comment)
    is restarted from its own latest checkpoint until it ends with any
    other status.
    """
    parser = argparse.ArgumentParser(description='Train or Test model')
    # `choices` rejects typos up front instead of silently doing nothing.
    parser.add_argument('--mode', type=str, default="train", choices=["train", "test"],
                        help="train or test")
    parser.add_argument('--resume_from', type=str, default='',
                        help="Whether resume last checkpoint from a past run. Notice: you only need to fill in the string after skgan_, i.e. the part with yyyy-mm-dd-hr-min-sec")
    parser.add_argument('--entry_point', type=str, default='train_single',
                        help="name of the training .py file")
    parser.add_argument('--batch_size', default=16, type=int, help='Batch size per gpu')
    parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations")
    parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations")
    parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter")
    # Validated by argparse rather than `assert`, which is stripped under -O.
    parser.add_argument('--optimizer', type=str, default="Adam",
                        choices=["RMSprop", "Adam", "AdaDelta", "AdaGrad"],
                        help="Optimizer for the graph")
    parser.add_argument('--lr_G', type=float, default=2e-4, help="learning rate for the generator")
    parser.add_argument('--lr_D', type=float, default=4e-4, help="learning rate for the discriminator")
    parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use")
    parser.add_argument('--distance_map', default=1, type=int, help="Whether using distance maps for sketches")
    parser.add_argument('--small_img', default=1, type=int, help="Whether using 64x64 instead of 256x256")
    parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record")
    args = parser.parse_args()

    # Everything except the mode switch is a run parameter.
    d_params = {key: value for key, value in vars(args).items() if key != 'mode'}

    if args.mode == 'test':
        launch_test(**d_params)
        return

    # launch_training yields (status, appendix), or None if setup failed.
    # Keep the whole result so each restart resumes from the run just
    # attempted and the loop condition always sees the latest status.
    result = launch_training(**d_params)
    while result is not None and result[0] == -1:  # NaN during training
        print("Training ended with status -1. Restarting..")
        d_params["resume_from"] = result[1]
        result = launch_training(**d_params)


if __name__ == "__main__":
    main()