Browse Source

Add dataset

master
wchen342 5 years ago
parent
commit
3aa146607b
  1. 13
      Readme.md
  2. 126
      data_processing/extract_images.py
  3. 4
      main_single.py
  4. 1
      src_single/graph_single.py

13
Readme.md

@ -13,10 +13,21 @@ Code for ["SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis"]
## Preparations
- The path to data files needs to be specified in `input_pipeline.py`. The dataset will be released shortly.
- The path to data files needs to be specified in `input_pipeline.py`. See below for detailed information on data files.
- You need to download ["Inception-V4 model"](http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz), unzip it and put the checkpoint under `inception_v4_model`.
## Dataset
Pre-built tfrecord files are available for out of the box training.
- Files for the Sketchy Database can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EtKmg1alDNdIl09WcvtJp_cBFs_7td3wKnb5FUcWZswEmw?e=eBGO6G).
- Files for Augmented Sketchy(i.e. flickr images+edge maps), resized to 256x256 regardless of original aspect ratios, can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EmF7KlhqZ8ZPnpzbTIMDKBoBcjMrezh3X2eS1P_KtWiGCQ?e=BJhFPF).
If you wish to get the original image files:
- The Sketchy Datqabase can be found [here](http://sketchy.eye.gatech.edu/).
- Use `extract_images.py` under `data_processing` to extract images from tfrecord files. You need to specify input and output paths. The extracted images will be sorted by class names.
- Please contact me if you need the original (not resized) Flickr images, since they are too large to upload to any online space.
## Configurations
The model can be trained out of the box, by running `main_single.py`. But there are several places you can change configurations:

126
data_processing/extract_images.py

@ -0,0 +1,126 @@
import os
import cv2
import numpy as np
import tensorflow as tf
datafile_path = "/media/cwl/Data/Programming/Others/acv_p_test1/py3/CycleGAN_sketchy/training_data/flickr_output_new"
image_output_path = "/media/cwl/Data/test/images"
edgemap_output_path = "/media/cwl/Data/test/edges"
def get_paired_input(filenames):
filename_queue = tf.train.string_input_producer(filenames, capacity=512, shuffle=False, num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'ImageNetID': tf.FixedLenFeature([], tf.string),
'SketchID': tf.FixedLenFeature([], tf.int64),
'Category': tf.FixedLenFeature([], tf.string),
'CategoryID': tf.FixedLenFeature([], tf.int64),
'Difficulty': tf.FixedLenFeature([], tf.int64),
'Stroke_Count': tf.FixedLenFeature([], tf.int64),
'WrongPose': tf.FixedLenFeature([], tf.int64),
'Context': tf.FixedLenFeature([], tf.int64),
'Ambiguous': tf.FixedLenFeature([], tf.int64),
'Error': tf.FixedLenFeature([], tf.int64),
'class_id': tf.FixedLenFeature([], tf.int64),
'is_test': tf.FixedLenFeature([], tf.int64),
'image_jpeg': tf.FixedLenFeature([], tf.string),
'sketch_png': tf.FixedLenFeature([], tf.string),
}
)
image = features['image_jpeg']
sketch = features['sketch_png']
# Attributes
category = features['Category']
# Not used
# imagenet_id = features['ImageNetID']
# sketch_id = features['SketchID']
# class_id = features['class_id']
# is_test = features['is_test']
# Stroke_Count = features['Stroke_Count']
# Difficulty = features['Difficulty']
# CategoryID = features['CategoryID']
# WrongPose = features['WrongPose']
# Context = features['Context']
# Ambiguous = features['Ambiguous']
# Error = features['Error']
return image, sketch, category
def build_queue(filenames, batch_size, capacity=1024):
image, sketch, category = get_paired_input(filenames)
images, sketchs, categories = tf.train.batch(
[image, sketch, category],
batch_size=1, capacity=capacity, num_threads=2, allow_smaller_final_batch=True)
return images, sketchs, categories
def extract_images(class_name):
filenames = sorted([os.path.join(datafile_path, f) for f in os.listdir(datafile_path)
if os.path.isfile(os.path.join(datafile_path, f)) and f.startswith(class_name)])
# Make dirs
this_image_path = os.path.join(image_output_path, class_name)
this_edgemap_path = os.path.join(edgemap_output_path, class_name)
if not os.path.isdir(image_output_path) and not os.path.exists(image_output_path):
os.makedirs(image_output_path)
if not os.path.isdir(this_image_path) and not os.path.exists(this_image_path):
os.makedirs(this_image_path)
if not os.path.isdir(edgemap_output_path) and not os.path.exists(edgemap_output_path):
os.makedirs(edgemap_output_path)
if not os.path.isdir(this_edgemap_path) and not os.path.exists(this_edgemap_path):
os.makedirs(this_edgemap_path)
# Read tfrecords
images, sketchs, categories = build_queue(filenames, 64)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
counter = 0
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
while True:
try:
raw_jpeg_data, raw_png_data, category_names = sess.run(
[images, sketchs, categories])
filename_appendix = "_%08d" % counter
with open(os.path.join(this_image_path, class_name + filename_appendix + '.jpg'), 'wb') as f:
f.write(raw_jpeg_data[0])
with open(os.path.join(this_edgemap_path, class_name + filename_appendix + '.png'), 'wb') as f:
f.write(raw_png_data[0])
counter += 1
except Exception as e:
print(e.args)
break
if counter % 100 == 0:
print("Now at iteration %d." % counter)
coord.request_stop()
coord.join(threads)
print()
if __name__ == "__main__":
# class_name = "airplane"
filenames = sorted([f for f in os.listdir(datafile_path) if os.path.isfile(os.path.join(datafile_path, f))])
class_names = sorted(list({f.replace('_', '.').split('.', 1)[0] for f in filenames}))
print('Num of classes found: %d' % len(class_names))
for cls in class_names:
extract_images(cls)

4
main_single.py

@ -148,14 +148,14 @@ if __name__ == "__main__":
parser.add_argument('--mode', type=str, default="train", help="train or test")
parser.add_argument('--resume_from', type=str, default='', help="Whether resume last checkpoint from a past run. Notice: you only need to fill in the string after skgan_, i.e. the part with yyyy-mm-dd-hr-min-sec")
parser.add_argument('--entry_point', type=str, default='train_single', help="name of the training .py file")
parser.add_argument('--batch_size', default=12, type=int, help='Batch size per gpu')
parser.add_argument('--batch_size', default=16, type=int, help='Batch size per gpu')
parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations")
parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations")
parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter")
parser.add_argument('--optimizer', type=str, default="Adam", help="Optimizer for the graph")
parser.add_argument('--lr_G', type=float, default=2e-4, help="learning rate for the generator")
parser.add_argument('--lr_D', type=float, default=4e-4, help="learning rate for the discriminator")
parser.add_argument('--num_gpu', default=1, type=int, help="Number of GPUs to use")
parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use")
parser.add_argument('--distance_map', default=1, type=int, help="Whether using distance maps for sketches")
parser.add_argument('--small_img', default=1, type=int, help="Whether using 64x64 instead of 256x256")
parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record")

1
src_single/graph_single.py

@ -250,7 +250,6 @@ def build_single_graph(images, sketches, images_d,
reuse=True, data_format=data_format, output_channel=3)
# Discriminator
# Stage 1
real_disc_out, real_logit = discriminator(images_d, num_classes=num_classes, labels=image_data_class_id_d,
reuse=False, data_format=data_format, scope_name='discriminator')
fake_disc_out, fake_logit = discriminator(image_gens, num_classes=num_classes, labels=image_labels,

Loading…
Cancel
Save