wchen342 5 years ago
  1. 13
  2. 126
  3. 4
  4. 1


## Preparations
- The path to data files needs to be specified in ``. The dataset will be released shortly.
- The path to data files needs to be specified in ``. See below for detailed information on data files.
- You need to download ["Inception-V4 model"](, unzip it and put the checkpoint under `inception_v4_model`.
## Dataset
Pre-built tfrecord files are available for out of the box training.
- Files for the Sketchy Database can be found [here](
- Files for Augmented Sketchy (i.e., Flickr images + edge maps), resized to 256x256 regardless of the original aspect ratio, can be found [here](
If you wish to get the original image files:
- The Sketchy Database can be found [here](
- Use `` under `data_processing` to extract images from tfrecord files. You need to specify input and output paths. The extracted images will be sorted by class names.
- Please contact me if you need the original (not resized) Flickr images, since they are too large to upload to any online space.
## Configurations
The model can be trained out of the box, by running ``. But there are several places you can change configurations:


import os
import cv2
import numpy as np
import tensorflow as tf
# Directory that is listed for input tfrecord files (one or more files per class).
datafile_path = "/media/cwl/Data/Programming/Others/acv_p_test1/py3/CycleGAN_sketchy/training_data/flickr_output_new"
# Root directory for extracted JPEG images; a sub-directory per class name is used below it.
image_output_path = "/media/cwl/Data/test/images"
# Root directory for extracted PNG edge maps; a sub-directory per class name is used below it.
edgemap_output_path = "/media/cwl/Data/test/edges"
def get_paired_input(filenames):
    """Build the ops that read one paired record from a list of tfrecord files.

    The files are read in order (no shuffling) for exactly one epoch. Per the
    TensorFlow documentation, ``num_epochs`` makes ``string_input_producer``
    create a local epoch counter, so the caller has to run
    ``tf.local_variables_initializer()`` before starting the queue runners.

    Args:
        filenames: list of tfrecord file paths.

    Returns:
        A tuple ``(image, sketch, category)`` of scalar string tensors holding
        the raw ``image_jpeg`` bytes, the raw ``sketch_png`` bytes and the
        ``Category`` field of one record.
    """
    filename_queue = tf.train.string_input_producer(
        filenames, capacity=512, shuffle=False, num_epochs=1)
    reader = tf.TFRecordReader()
    # NOTE(review): the right-hand side of this assignment and the
    # `serialized_example, features={...}` wrapper below were missing in the
    # source this was recovered from; restored to the standard TF1 pattern.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'ImageNetID': tf.FixedLenFeature([], tf.string),
            'SketchID': tf.FixedLenFeature([], tf.int64),
            'Category': tf.FixedLenFeature([], tf.string),
            'CategoryID': tf.FixedLenFeature([], tf.int64),
            'Difficulty': tf.FixedLenFeature([], tf.int64),
            'Stroke_Count': tf.FixedLenFeature([], tf.int64),
            'WrongPose': tf.FixedLenFeature([], tf.int64),
            'Context': tf.FixedLenFeature([], tf.int64),
            'Ambiguous': tf.FixedLenFeature([], tf.int64),
            'Error': tf.FixedLenFeature([], tf.int64),
            'class_id': tf.FixedLenFeature([], tf.int64),
            'is_test': tf.FixedLenFeature([], tf.int64),
            'image_jpeg': tf.FixedLenFeature([], tf.string),
            'sketch_png': tf.FixedLenFeature([], tf.string),
        })

    # Raw encoded bytes; they are written to disk as-is, never decoded here.
    image = features['image_jpeg']
    sketch = features['sketch_png']

    # Attributes
    category = features['Category']

    # The remaining parsed fields (ImageNetID, SketchID, class_id, is_test,
    # Stroke_Count, Difficulty, CategoryID, WrongPose, Context, Ambiguous,
    # Error) are not used by the extraction script.
    return image, sketch, category
def build_queue(filenames, batch_size, capacity=1024):
    """Wrap the single-record reader in a batching queue.

    Args:
        filenames: list of tfrecord file paths, forwarded to
            ``get_paired_input``.
        batch_size: accepted for interface compatibility but not used; the
            queue always emits batches of exactly one record.
        capacity: capacity passed to ``tf.train.batch``.

    Returns:
        A tuple ``(images, sketchs, categories)`` of batched tensors.
    """
    record_tensors = list(get_paired_input(filenames))
    batched = tf.train.batch(
        record_tensors,
        batch_size=1,
        capacity=capacity,
        num_threads=2,
        allow_smaller_final_batch=True,
    )
    images, sketchs, categories = batched
    return images, sketchs, categories
def extract_images(class_name):
    """Extract every image / edge-map pair of one class from its tfrecord files.

    All regular files in ``datafile_path`` whose name starts with
    ``class_name`` are read once, in sorted order. Each record's JPEG bytes are
    written to ``<image_output_path>/<class_name>/<class_name>_<counter>.jpg``
    and its PNG bytes to the matching path under ``edgemap_output_path``.

    NOTE(review): several statements of this function (directory creation,
    the ``try:`` line, the ``sess.run`` call, the ``f.write`` calls and the
    ``except`` body) were missing in the source this was recovered from and
    have been reconstructed; confirm against the upstream script.

    NOTE(review): the prefix match also selects files of any other class whose
    name merely begins with ``class_name``; verify this cannot happen with the
    class names in use.

    Args:
        class_name: class-name prefix used to select input files and to name
            the output sub-directories and files.
    """
    filenames = sorted(
        os.path.join(datafile_path, f) for f in os.listdir(datafile_path)
        if os.path.isfile(os.path.join(datafile_path, f)) and f.startswith(class_name))

    # Make dirs (parents included; no error if they already exist).
    this_image_path = os.path.join(image_output_path, class_name)
    this_edgemap_path = os.path.join(edgemap_output_path, class_name)
    for out_dir in (this_image_path, this_edgemap_path):
        os.makedirs(out_dir, exist_ok=True)

    # Use a fresh graph per call: queue runners are registered on the graph,
    # so reusing the default graph would restart the readers of every class
    # processed earlier each time this function is called again.
    with tf.Graph().as_default():
        # Read tfrecords. build_queue always yields batches of one record.
        images, sketchs, categories = build_queue(filenames, 64)

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # The one-epoch filename queue keeps its epoch counter in a local
            # variable, which must be initialised before the runners start.
            sess.run(tf.local_variables_initializer())

            counter = 0
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while True:
                    try:
                        raw_jpeg_data, raw_png_data, _category_names = sess.run(
                            [images, sketchs, categories])
                    except tf.errors.OutOfRangeError:
                        # Input exhausted after one epoch: normal termination.
                        break
                    except Exception as e:
                        # Best effort, as in the original: report and stop
                        # this class instead of aborting the whole run.
                        print(e)
                        break

                    filename_appendix = "_%08d" % counter
                    with open(os.path.join(this_image_path, class_name + filename_appendix + '.jpg'), 'wb') as f:
                        f.write(raw_jpeg_data[0])
                    with open(os.path.join(this_edgemap_path, class_name + filename_appendix + '.png'), 'wb') as f:
                        f.write(raw_png_data[0])
                    counter += 1

                    if counter % 100 == 0:
                        print("Now at iteration %d." % counter)
            finally:
                # Always shut the reader threads down before the session closes.
                coord.request_stop()
                coord.join(threads)
if __name__ == "__main__":
    # Discover the classes from the data directory: the class name is the part
    # of each file name before its first '_' or '.'.
    filenames = sorted([f for f in os.listdir(datafile_path) if os.path.isfile(os.path.join(datafile_path, f))])
    class_names = sorted({f.replace('_', '.').split('.', 1)[0] for f in filenames})
    print('Num of classes found: %d' % len(class_names))
    # NOTE(review): the loop body was missing in the source this was recovered
    # from; extracting each discovered class is the reconstructed intent.
    # To process a single class instead, call e.g. extract_images("airplane").
    for cls in class_names:
        extract_images(cls)


# Command-line options for training / testing.
# NOTE(review): `--batch_size` and `--num_gpu` were each registered twice
# (defaults 12 then 16, and 1 then 2), which makes argparse raise
# ArgumentError on the second registration. Only the later definition of each
# is kept here; confirm that 16 and 2 are the intended defaults.
parser.add_argument('--mode', type=str, default="train", help="train or test")
parser.add_argument('--resume_from', type=str, default='', help="Whether resume last checkpoint from a past run. Notice: you only need to fill in the string after skgan_, i.e. the part with yyyy-mm-dd-hr-min-sec")
parser.add_argument('--entry_point', type=str, default='train_single', help="name of the training .py file")
parser.add_argument('--batch_size', default=16, type=int, help='Batch size per gpu')
parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations")
parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations")
parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter")
parser.add_argument('--optimizer', type=str, default="Adam", help="Optimizer for the graph")
parser.add_argument('--lr_G', type=float, default=2e-4, help="learning rate for the generator")
parser.add_argument('--lr_D', type=float, default=4e-4, help="learning rate for the discriminator")
parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use")
# The two flags below are integers used as booleans (per their help text).
parser.add_argument('--distance_map', default=1, type=int, help="Whether using distance maps for sketches")
parser.add_argument('--small_img', default=1, type=int, help="Whether using 64x64 instead of 256x256")
parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record")


reuse=True, data_format=data_format, output_channel=3)
# Discriminator
# Stage 1
real_disc_out, real_logit = discriminator(images_d, num_classes=num_classes, labels=image_data_class_id_d,
reuse=False, data_format=data_format, scope_name='discriminator')
fake_disc_out, fake_logit = discriminator(image_gens, num_classes=num_classes, labels=image_labels,