update fcn test data reader
qhan1028 committed Jan 19, 2018
1 parent bb34c49 commit cc75e14
Showing 3 changed files with 66 additions and 78 deletions.
README.md (10 changes: 4 additions & 6 deletions)

@@ -74,12 +74,10 @@

 ### Quick Usage
 * Train
-    * `python3.5 FCN.py -m train`
+    * `python3.5 fcn.py -m train`
 * Visualize
-    * `python3.5 FCN.py -m visualize`
+    * `python3.5 fcn.py -m visualize`
 * Test
-    * `python3.5 FCN.py -m test -tl <test_list>`
-* Mat video
-    * `./mat.sh <video_name>`
+    * `python3.5 fcn.py -m test -tl <test_list>`
 * To see full usage
-    * `python3.5 FCN.py --help`
+    * `python3.5 fcn.py --help`
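Note: reader.py below replaces the `--testlist` flag with `-td/--test-dir`, so the test command presumably becomes `python3.5 fcn.py -m test -td <test_dir>` once test images live in a directory; the README line above still documents the older `-tl <test_list>` form.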
fcn.py (78 changes: 37 additions & 41 deletions)

@@ -168,31 +168,32 @@ def main():

     # Summary
     print('====================================================')
-    tf.summary.image("input_image", image, max_outputs=4)
-    tf.summary.image("ground_truth", tf.cast(annotation * 255, tf.uint8), max_outputs=4)
-    tf.summary.image("pred_annotation", tf.cast(pred_annotation * 255, tf.uint8), max_outputs=4)
-    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
-                                                                          labels=tf.squeeze(annotation, squeeze_dims=[3]),
-                                                                          name="entropy")))
-    tf.summary.scalar("train_entropy", loss)
-
-    trainable_var = tf.trainable_variables()
-    if args.debug:
-        for var in trainable_var:
-            utils.add_to_regularization_and_summary(var)
-    train_op = train(loss, trainable_var)
+    if args.mode != 'test':
+        tf.summary.image("input_image", image, max_outputs=4)
+        tf.summary.image("ground_truth", tf.cast(annotation * 255, tf.uint8), max_outputs=4)
+        tf.summary.image("pred_annotation", tf.cast(pred_annotation * 255, tf.uint8), max_outputs=4)
+        loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
+                                                                              labels=tf.squeeze(annotation, squeeze_dims=[3]),
+                                                                              name="entropy")))
+        tf.summary.scalar("train_entropy", loss)
+
+        trainable_var = tf.trainable_variables()
+        if args.debug:
+            for var in trainable_var:
+                utils.add_to_regularization_and_summary(var)
+        train_op = train(loss, trainable_var)
 
-    print("> [FCN] Setting up summary op...")
-    summary_op = tf.summary.merge_all()
+        print("> [FCN] Setting up summary op...")
+        summary_op = tf.summary.merge_all()
 
-    # Validation summary
-    val_summary = tf.summary.scalar("validation_entropy", loss)
+        # Validation summary
+        val_summary = tf.summary.scalar("validation_entropy", loss)
 
-    # Read data
-    print("> [FCN] Setting up image reader...")
-    train_records, valid_records = read_dataset(args.data_dir)
-    print('> [FCN] Train len:', len(train_records))
-    print('> [FCN] Val len:', len(valid_records))
+        # Read data
+        print("> [FCN] Setting up image reader...")
+        train_records, valid_records = read_dataset(args.data_dir)
+        print('> [FCN] Train len:', len(train_records))
+        print('> [FCN] Val len:', len(valid_records))
 
     t = timer.Timer() # Qhan's timer

@@ -201,12 +201,11 @@ def main():
     image_options = {'resize': True, 'resize_height': IMAGE_HEIGHT, 'resize_width': IMAGE_WIDTH}
     if args.mode == 'train':
         t.tic(); train_dataset_reader = dataset.BatchDatset(train_records, image_options, mode='train')
-        load_time = t.toc()
-        print('> [FCN] Train data set loaded. %.4f ms' % (load_time))
+        print('> [FCN] Train data set loaded. %.4f ms' % t.toc())
     t.tic(); validation_dataset_reader = dataset.BatchDatset(valid_records, image_options, mode='val')
-    load_time = t.toc()
-    print('> [FCN] Validation data set loaded. %.4f ms' % (load_time))
+    print('> [FCN] Validation data set loaded. %.4f ms' % t.toc())
 
     # Setup Session
     gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.90, allow_growth=True)
     sess = tf.Session( config=tf.ConfigProto(gpu_options=gpu_options) )

@@ -217,7 +217,7 @@ def main():

print("> [FCN] Initialize variables... ", flush=True, end='')
t.tic(); sess.run(tf.global_variables_initializer())
print('%.4f ms' % (t.toc()))
print('%.4f ms' % t.toc())

t.tic()
ckpt = tf.train.get_checkpoint_state(args.logs_dir)
@@ -259,46 +259,42 @@ def main():
             if itr % 100 == 0 and itr != 0:
                 valid_images, valid_annotations = validation_dataset_reader.next_batch(args.batch_size * 2)
                 val_feed_dict = { image: valid_images, annotation: valid_annotations, keep_probability: 1.0}
-                t.tic()
-                val_loss, val_str = sess.run([loss, val_summary], feed_dict=val_feed_dict)
-                val_time = t.toc()
+                t.tic(); val_loss, val_str = sess.run([loss, val_summary], feed_dict=val_feed_dict)
+                print("[%6d], Validation_loss: %g, %.4f ms" % (itr, val_loss, t.toc()))
 
                 summary_writer.add_summary(val_str, itr)
-                print("[%6d], Validation_loss: %g, %.4f ms" % (itr, val_loss, val_time))
 
             if itr % 1000 == 0 and itr != 0:
                 saver.save(sess, args.logs_dir + "model.ckpt", itr)
 
     elif args.mode == 'visualize':
         for itr in range(20):
             valid_images, valid_annotations = validation_dataset_reader.get_random_batch(1)
-            t.tic()
-            pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability: 1.0})
-            val_time = t.toc()
+            t.tic(); pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability: 1.0})
+            print("> [FCN] Saved image: %d, %.4f ms" % (itr, t.toc()))
 
             valid_annotations = np.squeeze(valid_annotations, axis=3)
             pred = np.squeeze(pred, axis=3)
 
             utils.save_image(valid_images[0].astype(np.uint8), args.res_dir, name="inp_" + str(itr))
             utils.save_image(valid_annotations[0].astype(np.uint8), args.res_dir, name="gt_" + str(itr))
             utils.save_image(pred[0].astype(np.uint8), args.res_dir, name="pred_" + str(itr))
-            print("> [FCN] Saved image: %d, %.4f ms" % (itr, val_time))
 
     elif args.mode == 'test':
-        testlist = args.testlist
-        images, names, (H, W) = read_test_data(testlist, IMAGE_HEIGHT, IMAGE_WIDTH)
+        images, names, (H, W) = read_test_data(args.test_dir, IMAGE_HEIGHT, IMAGE_WIDTH)
         for i, (im, name) in enumerate(zip(images, names)):
 
-            t.tic()
-            pred = sess.run(pred_annotation, feed_dict={image: im.reshape((1,) + im.shape), keep_probability: 1.0})
-            test_time = t.toc()
+            t.tic(); pred = sess.run(pred_annotation, feed_dict={image: im.reshape((1,) + im.shape), keep_probability: 1.0})
+            print('> [FCN] Test: %d,' % (i) + ' Name: ' + name + ', %.4f ms' % t.toc())
 
             pred = pred.reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
             if args.video:
                 save_video_image(im, pred, args.res_dir + '/pred_%05d' % (i) + '.png', H, W)
             else:
                 misc.imsave(args.res_dir + '/inp_%d' % (i) + '.png', im.astype(np.uint8))
                 misc.imsave(args.res_dir + '/pred_%d' % (i) + '.png', pred.astype(np.uint8))
-            print('> [FCN] Img: %d,' % (i) + ' Name: ' + name + ', %.4f ms' % test_time)
 
     else:
         pass
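The `t.tic(); ...; t.toc()` pairs throughout fcn.py come from the repository's own `timer.Timer` helper ("Qhan's timer"), which this diff does not show. A minimal sketch of a compatible stopwatch, assuming `toc()` returns elapsed milliseconds as the `%.4f ms` format strings suggest:

```python
import time

class Timer:
    """Minimal tic/toc stopwatch; a hypothetical stand-in for the repo's timer.Timer."""

    def __init__(self):
        self._start = None

    def tic(self):
        # Mark the start of a measurement.
        self._start = time.time()

    def toc(self):
        # Elapsed time since tic(), in milliseconds (printed as '%.4f ms' in fcn.py).
        return (time.time() - self._start) * 1000.0
```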
reader.py (56 changes: 25 additions & 31 deletions)

@@ -6,6 +6,10 @@
 from __future__ import print_function
 
 import argparse
+import numpy as np
+import os
+import os.path as osp
+import cv2
 
 # Argument parser
 def parse_args():
@@ -20,8 +24,8 @@ def parse_args():
     parser.add_argument('-ld', '--logs-dir', metavar='DIR', default='logs', nargs='?', help='Path to logs directory.')
     parser.add_argument('-rd', '--res-dir', metavar='DIR', default='res', nargs='?', help='Path to result directory.')
     parser.add_argument('-md', '--model-dir', metavar='DIR', default='Model_zoo', nargs='?', help='Path to vgg pretrained model.')
+    parser.add_argument('-td', '--test-dir', metavar='DIR', default='test', nargs='?', help='Test directory.')
     parser.add_argument('--debug', action='store_true', default=True, help='Debug mode.')
-    parser.add_argument('--testlist', metavar='FILE', default='testlist.txt', nargs='?', help='Test list for testing.')
     parser.add_argument('-v', '--video', action='store_true', default=False, help='Resize back to original size.')
     args = parser.parse_args()

@@ -37,41 +41,31 @@ def parse_args():
     return args
 
 
-import numpy as np
-import scipy.misc as misc
-from PIL import Image
-
 '''
 :filename: test data list
 :resize_size:
 :return: np array of images, names, original size
 '''
-def read_test_data(listname, height, width):
+def read_test_data(testdir, height, width):
 
+    if testdir[-1] != '/': testdir += '/'
+    oh, ow = None, None
     images, names = [], []
 
-    with open(listname, 'r') as f:
-
-        image_dir = f.readline()[:-1]
-
-        for line in f:
-
-            name = line[:-1]
-            path = image_dir + '/' + name
-            print('\rpath: ' + path, end='', flush=True)
-            image = Image.open(path)
-            (w, h) = image.size
-            #max_edge = max(w, h)
-            #image = np.array( image.crop((0, 0, max_edge, max_edge)) )
-            resized_image = misc.imresize(image, [height, width], interp='nearest')
+    for filename in os.listdir(testdir):
+        name, ext = osp.splitext(filename)
+        if ext in ['.jpg', '.png', '.gif']:
+            print('\r> image:', filename, end='')
+            im = cv2.imread(testdir + filename)
+
+            if oh is None:
+                oh, ow = im.shape[:2]
+
+            im_resized = cv2.resize(im, (width, height), interpolation=cv2.INTER_CUBIC)
+            im_rgb = cv2.cvtColor(im_resized, cv2.COLOR_BGR2RGB)
 
             names.append(name)
-            images.append(resized_image)
-
-    print('')
-    #h, w, _ = image.shape
+            images.append(im_rgb)
+        else:
+            print('> skip:', filename)
+
+    print('')
 
-    return np.array(images), np.array(names), (h, w)
+    return np.array(images), np.array(names), (oh, ow)


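With the updated reader, testing no longer needs a `testlist.txt`: images are collected straight from a directory, resized, and converted from OpenCV's BGR order to RGB. A hedged usage sketch (the directory name and target size are illustrative, not taken from this diff):

```python
# Hypothetical call into the updated reader; assumes a test/ directory
# containing .jpg, .png, or .gif files.
from reader import read_test_data

images, names, (H, W) = read_test_data('test', 256, 256)  # 256x256 is a placeholder size
print(images.shape)      # (N, 256, 256, 3), RGB after the BGR -> RGB conversion
print(names[0], (H, W))  # base filename (extension stripped) and the first image's original size
```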
