diff --git a/Readme.md b/Readme.md
index 39271d5..030796a 100644
--- a/Readme.md
+++ b/Readme.md
@@ -13,10 +13,26 @@ Code for ["SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis"]
 
 ## Preparations
 
-- The path to data files needs to be specified in `input_pipeline.py`. The dataset will be released shortly.
+- The path to data files needs to be specified in `input_pipeline.py`. See below for detailed information on data files.
 - You need to download ["Inception-V4 model"](http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz), unzip it and put the checkpoint under `inception_v4_model`.
 
+## Dataset
+Pre-built tfrecord files are available for out-of-the-box training.
+- Files for the Sketchy Database can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EtKmg1alDNdIl09WcvtJp_cBFs_7td3wKnb5FUcWZswEmw?e=eBGO6G).
+- Files for Augmented Sketchy (i.e. Flickr images + edge maps), resized to 256x256 regardless of the original aspect ratio, can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EmF7KlhqZ8ZPnpzbTIMDKBoBcjMrezh3X2eS1P_KtWiGCQ?e=BJhFPF).
+
+If you wish to get the original image files:
+- The Sketchy Database can be found [here](http://sketchy.eye.gatech.edu/).
+- Use `extract_images.py` under `data_processing` to extract images from the tfrecord files. You need to specify the input and output paths; see the example below. The extracted images will be sorted by class name.
+- Please contact me if you need the original (not resized) Flickr images, since they are too large to upload to any online space.
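+
+For example, after editing `datafile_path`, `image_output_path` and `edgemap_output_path` at the top of the script to point at your own folders, extraction for all classes is run as:
+```
+python data_processing/extract_images.py
+```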
+
+
 ## Configurations
 
 The model can be trained out of the box, by running `main_single.py`. But there are several places you can change configurations:
diff --git a/data_processing/extract_images.py b/data_processing/extract_images.py
new file mode 100644
index 0000000..9649248
--- /dev/null
+++ b/data_processing/extract_images.py
@@ -0,0 +1,131 @@
+import os
+import cv2
+import numpy as np
+import tensorflow as tf
+
+datafile_path = "/media/cwl/Data/Programming/Others/acv_p_test1/py3/CycleGAN_sketchy/training_data/flickr_output_new"
+image_output_path = "/media/cwl/Data/test/images"
+edgemap_output_path = "/media/cwl/Data/test/edges"
+
+
+def get_paired_input(filenames):
+    filename_queue = tf.train.string_input_producer(filenames, capacity=512, shuffle=False, num_epochs=1)
+    reader = tf.TFRecordReader()
+
+    _, serialized_example = reader.read(filename_queue)
+
+    features = tf.parse_single_example(
+        serialized_example,
+        features={
+            'ImageNetID': tf.FixedLenFeature([], tf.string),
+            'SketchID': tf.FixedLenFeature([], tf.int64),
+            'Category': tf.FixedLenFeature([], tf.string),
+            'CategoryID': tf.FixedLenFeature([], tf.int64),
+            'Difficulty': tf.FixedLenFeature([], tf.int64),
+            'Stroke_Count': tf.FixedLenFeature([], tf.int64),
+            'WrongPose': tf.FixedLenFeature([], tf.int64),
+            'Context': tf.FixedLenFeature([], tf.int64),
+            'Ambiguous': tf.FixedLenFeature([], tf.int64),
+            'Error': tf.FixedLenFeature([], tf.int64),
+            'class_id': tf.FixedLenFeature([], tf.int64),
+            'is_test': tf.FixedLenFeature([], tf.int64),
+            'image_jpeg': tf.FixedLenFeature([], tf.string),
+            'sketch_png': tf.FixedLenFeature([], tf.string),
+        }
+    )
+
+    # Raw encoded JPEG/PNG bytes; they are written to disk as-is, without decoding
+    image = features['image_jpeg']
+    sketch = features['sketch_png']
+
+    # Attributes
+    category = features['Category']
+    # Not used
+    # imagenet_id = features['ImageNetID']
+    # sketch_id = features['SketchID']
+    # class_id = features['class_id']
+    # is_test = features['is_test']
+    # Stroke_Count = features['Stroke_Count']
+    # Difficulty = features['Difficulty']
+    # CategoryID = features['CategoryID']
+    # WrongPose = features['WrongPose']
+    # Context = features['Context']
+    # Ambiguous = features['Ambiguous']
+    # Error = features['Error']
+
+    return image, sketch, category
+
+
+def build_queue(filenames, batch_size, capacity=1024):
+    image, sketch, category = get_paired_input(filenames)
+
+    images, sketches, categories = tf.train.batch(
+        [image, sketch, category],
+        batch_size=batch_size, capacity=capacity, num_threads=2, allow_smaller_final_batch=True)
+
+    return images, sketches, categories
+
+
+def extract_images(class_name):
+    filenames = sorted([os.path.join(datafile_path, f) for f in os.listdir(datafile_path)
+                        if os.path.isfile(os.path.join(datafile_path, f)) and f.startswith(class_name)])
+
+    # Make dirs
+    this_image_path = os.path.join(image_output_path, class_name)
+    this_edgemap_path = os.path.join(edgemap_output_path, class_name)
+    if not os.path.exists(image_output_path):
+        os.makedirs(image_output_path)
+    if not os.path.exists(this_image_path):
+        os.makedirs(this_image_path)
+    if not os.path.exists(edgemap_output_path):
+        os.makedirs(edgemap_output_path)
+    if not os.path.exists(this_edgemap_path):
+        os.makedirs(this_edgemap_path)
+
+    # Read tfrecords one example at a time
+    images, sketches, categories = build_queue(filenames, 1)
+
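+    # Note on the TF1 queue pipeline used below: string_input_producer above was
+    # created with num_epochs=1, start_queue_runners launches the reader threads,
+    # and every sess.run fetches one example until the queue is exhausted and
+    # raises OutOfRangeError, which ends the extraction loop.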
+    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
+        sess.run(tf.global_variables_initializer())
+        sess.run(tf.local_variables_initializer())
+
+        counter = 0
+
+        coord = tf.train.Coordinator()
+        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+        while True:
+            try:
+                raw_jpeg_data, raw_png_data, category_names = sess.run(
+                    [images, sketches, categories])
+                filename_appendix = "_%08d" % counter
+                with open(os.path.join(this_image_path, class_name + filename_appendix + '.jpg'), 'wb') as f:
+                    f.write(raw_jpeg_data[0])
+                with open(os.path.join(this_edgemap_path, class_name + filename_appendix + '.png'), 'wb') as f:
+                    f.write(raw_png_data[0])
+
+                counter += 1
+            except tf.errors.OutOfRangeError:
+                # Expected once all records for this class have been consumed
+                break
+
+            if counter % 100 == 0:
+                print("Now at iteration %d." % counter)
+
+        coord.request_stop()
+        coord.join(threads)
+        print()
+
+
+if __name__ == "__main__":
+    # class_name = "airplane"
+    filenames = sorted([f for f in os.listdir(datafile_path) if os.path.isfile(os.path.join(datafile_path, f))])
+    class_names = sorted(list({f.replace('_', '.').split('.', 1)[0] for f in filenames}))
+    print('Num of classes found: %d' % len(class_names))
+
+    for cls in class_names:
+        extract_images(cls)
diff --git a/main_single.py b/main_single.py
index 26b8b90..a5112a4 100644
--- a/main_single.py
+++ b/main_single.py
@@ -148,14 +148,14 @@ if __name__ == "__main__":
     parser.add_argument('--mode', type=str, default="train", help="train or test")
     parser.add_argument('--resume_from', type=str, default='', help="Whether resume last checkpoint from a past run. Notice: you only need to fill in the string after skgan_, i.e. the part with yyyy-mm-dd-hr-min-sec")
     parser.add_argument('--entry_point', type=str, default='train_single', help="name of the training .py file")
-    parser.add_argument('--batch_size', default=12, type=int, help='Batch size per gpu')
+    parser.add_argument('--batch_size', default=16, type=int, help='Batch size per gpu')
     parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations")
     parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations")
     parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter")
     parser.add_argument('--optimizer', type=str, default="Adam", help="Optimizer for the graph")
     parser.add_argument('--lr_G', type=float, default=2e-4, help="learning rate for the generator")
     parser.add_argument('--lr_D', type=float, default=4e-4, help="learning rate for the discriminator")
-    parser.add_argument('--num_gpu', default=1, type=int, help="Number of GPUs to use")
+    parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use")
     parser.add_argument('--distance_map', default=1, type=int, help="Whether using distance maps for sketches")
     parser.add_argument('--small_img', default=1, type=int, help="Whether using 64x64 instead of 256x256")
     parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record")
diff --git a/src_single/graph_single.py b/src_single/graph_single.py
index 830bb22..cdda43e 100644
--- a/src_single/graph_single.py
+++ b/src_single/graph_single.py
@@ -250,7 +250,6 @@ def build_single_graph(images, sketches, images_d,
                        reuse=True, data_format=data_format, output_channel=3)
 
     # Discriminator
-    # Stage 1
     real_disc_out, real_logit = discriminator(images_d, num_classes=num_classes, labels=image_data_class_id_d,
                                               reuse=False, data_format=data_format, scope_name='discriminator')
     fake_disc_out, fake_logit = discriminator(image_gens, num_classes=num_classes, labels=image_labels,