Code for the paper "SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis".
This script converts the Sketchy dataset (paired photos and sketches) into per-class TFRecord files.
import os
import sys
import csv
import numpy as np
import scipy.io
import scipy.misc as spm
import cv2
from scipy import ndimage
import tensorflow as tf
from tensorflow.python.framework import ops
def showImg(img):
    """Display *img* in an OpenCV window and block until a key is pressed.

    Debug helper only; not called in the main pipeline below.
    """
    cv2.imshow("test", img)
    cv2.waitKey(-1)  # -1 waits indefinitely for any key press
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Args:
        labels_dense: integer array of shape (num_labels,) with values
            in [0, num_classes).
        num_classes: total number of classes (width of the one-hot rows).

    Returns:
        int32 array of shape (num_labels, num_classes) with a single 1
        per row at the label's index.
    """
    num_labels = labels_dense.shape[0]
    # Offset of the start of each row in the flattened output array.
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.int32)
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
def _bytes_feature(value):
    """Wrap a single bytes value as a tf.train.Feature for an Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
    """Wrap a single int value as a tf.train.Feature for an Example proto."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# CSV listing the dataset's class names (column 'Name'); row order fixes class ids.
classes_info = '../data_processing/classes.csv'
# Sketchy dataset layout: photos and sketches rendered at 256x256, transform tx_000000000000.
photo_folder = '../Datasets/Sketchy/rendered_256x256/256x256/photo/tx_000000000000'
sketch_folder = '../Datasets/Sketchy/rendered_256x256/256x256/sketch/tx_000000000000'
# Metadata folder: stats.csv, testset.txt and the invalid-*.txt annotation lists.
info_dir = '../Datasets/Sketchy/info'
# Output directory for the per-class .tfrecord files.
data_dir = '../tfrecords/sketchy'
# Shared session config for the encode/decode graph built in build_graph().
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False,
intra_op_parallelism_threads=4)
def check_repeat(seq):
    """Return the distinct elements of *seq* that appear more than once.

    Args:
        seq: iterable of hashable items.

    Returns:
        list of items seen at least twice (each listed once; set order).
    """
    seen = set()
    seen_add = seen.add  # bind once: avoids attribute lookup per element
    # seen_add() returns None (falsy), so the `or` both records the item
    # and lets first occurrences fall through the filter.
    seen_twice = set(x for x in seq if x in seen or seen_add(x))
    return list(seen_twice)
def build_graph():
    """Build the TF ops used to read, decode and re-encode images.

    Returns a 12-tuple of placeholders/tensors consumed by
    write_image_data():
        photo_filename, label_filename: string placeholders (paths to feed).
        photo, label: raw file contents read from those paths.
        photo_decoded, label_decoded: decoded JPEG photo / PNG sketch.
        photo_input, label_input, label_small_input: uint8 image
            placeholders for re-encoding (64x64x3 photo, 256x256x1 and
            64x64x1 single-channel maps).
        photo_stream, label_stream, label_small_stream: encoded JPEG/PNG
            byte strings for the above inputs.
    """
    # Readers: feed a path, fetch raw bytes and/or the decoded pixels.
    photo_filename = tf.placeholder(dtype=tf.string, shape=())
    label_filename = tf.placeholder(dtype=tf.string, shape=())
    photo = tf.read_file(photo_filename)
    label = tf.read_file(label_filename)
    photo_decoded = tf.image.decode_jpeg(photo, fancy_upscaling=True)
    label_decoded = tf.image.decode_png(label)
    # Encoders for the 64x64 down-sampled photo and the 1-channel maps.
    photo_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 3))
    label_input = tf.placeholder(dtype=tf.uint8, shape=(256, 256, 1))
    label_small_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 1))
    photo_stream = tf.image.encode_jpeg(photo_input, quality=95, progressive=False,
                                        optimize_size=False, chroma_downsampling=False)
    label_stream = tf.image.encode_png(label_input, compression=7)
    label_small_stream = tf.image.encode_png(label_small_input, compression=7)
    return photo_filename, label_filename, photo, label, photo_decoded, label_decoded, photo_input, label_input,\
        label_small_input, photo_stream, label_stream, label_small_stream
def read_csv(filename):
    """Read a CSV file with a header row into a list of dicts.

    Args:
        filename: path to the CSV file.

    Returns:
        list of OrderedDict/dict rows keyed by the header column names.
    """
    with open(filename) as csvfile:
        reader = csv.DictReader(csvfile)
        rows = list(reader)
    return rows
def read_txt(filename):
    """Read a text file and return its lines with the last character removed.

    NOTE(review): `l[:-1]` strips the trailing newline, but it also removes
    the final character of a last line that lacks a newline, and leaves a
    '\r' behind on CRLF files (callers strip a second char themselves).
    Kept as-is because callers compensate for it.

    Args:
        filename: path to the text file.

    Returns:
        list of lines, each with its final character dropped.
    """
    with open(filename) as txtfile:
        lines = txtfile.readlines()
    return [line[:-1] for line in lines]
def split_csvlist(stat_info):
    """Group CSV rows by their 'Category' field.

    Args:
        stat_info: list of dict rows, each with a 'Category' key.

    Returns:
        (cat, l): cat is the list of distinct category names (set order,
        not deterministic across runs); l is a parallel list where l[i]
        holds all rows whose 'Category' equals cat[i].
    """
    cat = list(set(item['Category'] for item in stat_info))
    # One pass per category keeps row order within each group stable.
    l = [[item for item in stat_info if item['Category'] == c] for c in cat]
    return cat, l
def binarize(sketch, threshold=245):
sketch[sketch < threshold] = 0
sketch[sketch >= threshold] = 255
return sketch
def write_image_data():
    """Convert the Sketchy dataset into one TFRecord file per class.

    For every (photo, sketch) pair listed in info/stats.csv this writes an
    Example containing the sketch metadata, the original 256x256 JPEG/PNG
    bytes, 64x64 down-sampled versions, and Euclidean distance maps of the
    binarized sketch at both resolutions.
    """
    csv_file = os.path.join(info_dir, 'stats.csv')
    stat_info = read_csv(csv_file)
    classes = read_csv(classes_info)
    classes_ids = [item['Name'] for item in classes]
    # Photos listed here are flagged is_test=1 in the records.
    test_list = read_txt(os.path.join(info_dir, 'testset.txt'))
    # NOTE(review): invalid_files is collected but never used below —
    # invalid sketches are not filtered out. Kept for parity with the
    # original behavior; confirm whether filtering was intended.
    invalid_notations = ['invalid-ambiguous.txt', 'invalid-context.txt', 'invalid-error.txt', 'invalid-pose.txt']
    invalid_files = []
    for txtfile in invalid_notations:
        cur_path = os.path.join(info_dir, txtfile)
        files = read_txt(cur_path)
        # read_txt already drops one trailing char; this second strip
        # presumably removes a leftover '\r' from CRLF files — TODO confirm.
        files = [f[:-1] for f in files]
        invalid_files.extend(files)
    path_image = photo_folder
    path_label = sketch_folder
    dirs, stats = split_csvlist(stat_info)
    photo_filename, label_filename, photo, label, photo_decoded, label_decoded, photo_input, label_input, \
        label_small_input, photo_stream, label_stream, label_small_stream = build_graph()
    assert len(dirs) == len(stats)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(len(dirs)):
            # Category names use spaces in the CSV but underscores on disk.
            class_dir = dirs[i].replace(' ', '_')
            print(class_dir)
            class_id = classes_ids.index(class_dir)
            stat = stats[i]
            writer = tf.python_io.TFRecordWriter(os.path.join(data_dir, class_dir + '.tfrecord'))
            cur_photo_path = os.path.join(path_image, class_dir)
            cur_label_path = os.path.join(path_label, class_dir)
            num_label = len(stat)
            for j in range(num_label):
                if j % 500 == 499:
                    print(j)  # progress marker every 500 sketches
                item = stat[j]
                ImageNetID = item['ImageNetID']
                SketchID = int(item['SketchID'])
                Category = item['Category']
                CategoryID = int(item['CategoryID'])
                Difficulty = int(item['Difficulty'])
                Stroke_Count = int(item['Stroke_Count'])
                WrongPose = int(item['WrongPose?'])
                Context = int(item['Context?'])
                Ambiguous = int(item['Ambiguous?'])
                Error = int(item['Error?'])
                IsTest = 1 if os.path.join(class_dir, ImageNetID + '.jpg') in test_list else 0
                # Raw bytes go straight into the record; decoded pixels are
                # used to build the small/derived images below.
                out_image, out_image_decoded = sess.run([photo, photo_decoded], feed_dict={
                    photo_filename: os.path.join(cur_photo_path, ImageNetID + '.jpg')})
                out_label, out_label_decoded = sess.run([label, label_decoded], feed_dict={
                    label_filename: os.path.join(cur_label_path, ImageNetID + '-' + str(SketchID) + '.png')})
                # Resize photo; collapse sketch RGB to single grayscale channel.
                out_image_decoded_small = cv2.resize(out_image_decoded, (64, 64), interpolation=cv2.INTER_AREA)
                out_label_decoded = (np.sum(out_label_decoded.astype(np.float64), axis=2) / 3).astype(np.uint8)
                out_label_decoded_small = cv2.resize(out_label_decoded, (64, 64), interpolation=cv2.INTER_AREA)
                # Distance maps: distance to the nearest stroke pixel,
                # normalized to [0, 255]. binarize() mutates its input.
                out_dist_map = ndimage.distance_transform_edt(binarize(out_label_decoded))
                out_dist_map = (out_dist_map / out_dist_map.max() * 255.).astype(np.uint8)
                out_dist_map_small = ndimage.distance_transform_edt(binarize(out_label_decoded_small))
                out_dist_map_small = (out_dist_map_small / out_dist_map_small.max() * 255.).astype(np.uint8)
                # Re-encode the derived images to compressed byte streams.
                image_string_small, label_string_small = sess.run([photo_stream, label_small_stream], feed_dict={
                    photo_input: out_image_decoded_small, label_small_input: out_label_decoded_small.reshape((64, 64, 1))
                })
                dist_map_string = sess.run(label_stream, feed_dict={label_input: out_dist_map.reshape((256, 256, 1))})
                dist_map_string_small = sess.run(label_small_stream, feed_dict={
                    label_small_input: out_dist_map_small.reshape((64, 64, 1))})
                example = tf.train.Example(features=tf.train.Features(feature={
                    'ImageNetID': _bytes_feature(ImageNetID.encode('utf-8')),
                    'SketchID': _int64_feature(SketchID),
                    'Category': _bytes_feature(Category.encode('utf-8')),
                    'CategoryID': _int64_feature(CategoryID),
                    'Difficulty': _int64_feature(Difficulty),
                    'Stroke_Count': _int64_feature(Stroke_Count),
                    'WrongPose': _int64_feature(WrongPose),
                    'Context': _int64_feature(Context),
                    'Ambiguous': _int64_feature(Ambiguous),
                    'Error': _int64_feature(Error),
                    'is_test': _int64_feature(IsTest),
                    'class_id': _int64_feature(class_id),
                    'image_jpeg': _bytes_feature(out_image),
                    'image_small_jpeg': _bytes_feature(image_string_small),
                    'sketch_png': _bytes_feature(out_label),
                    'sketch_small_png': _bytes_feature(label_string_small),
                    'dist_map_png': _bytes_feature(dist_map_string),
                    'dist_map_small_png': _bytes_feature(dist_map_string_small),
                }))
                writer.write(example.SerializeToString())
            # One TFRecord file per class: close before the next class starts.
            writer.close()
write_image_data()