How to fine-tune a deep neural network using TensorFlow
Overview
In this tutorial I will quickly show the steps to fine-tune any neural network using TensorFlow. I assume you already know how to save and load checkpoints in TensorFlow; TensorFlow uses checkpoint files to save the variables and operations of a graph. Checkpoints allow anyone to retrain a model with all or part of the weights from previous training runs.
In this example we will use VGG16. You can download the standard pre-processing for VGG on ImageNet and read the VGG paper for more details.
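As a quick refresher, saving and restoring a checkpoint with tf.train.Saver works as in the minimal sketch below (the variable name and checkpoint path here are illustrative only, not part of this tutorial's code):
import tensorflow as tf

w = tf.get_variable("w", shape=[2, 2], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "/tmp/demo_model.ckpt")  # write the variables to a checkpoint file

with tf.Session() as sess:
    saver.restore(sess, "/tmp/demo_model.ckpt")  # load the saved variables back
    print(sess.run(w))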
1. Pre-processing images
Pre-processing (for both training and validation):
1 Decode the image from JPEG format
2 Resize the image so its smaller side is 256 pixels long
import os
import tensorflow as tf

def _parse_function(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)  # (1)
    print("decode_jpeg", image_decoded.shape)
    image = tf.cast(image_decoded, tf.float32)
    print("cast", image.shape)

    smallest_side = 256.0
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    height = tf.to_float(height)
    width = tf.to_float(width)

    scale = tf.cond(tf.greater(height, width),
                    lambda: smallest_side / width,
                    lambda: smallest_side / height)
    new_height = tf.to_int32(height * scale)
    new_width = tf.to_int32(width * scale)

    resized_image = tf.image.resize_images(image, [new_height, new_width])  # (2)
    return resized_image, label
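For example, a 400x300 image (width x height) gets scale = 256 / 300, so it is resized to about 341x256 while preserving its aspect ratio.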
Pre-processing (for training)
3 Take a random 224x224 crop of the scaled image
4 Horizontally flip the image with probability 1/2
5 Subtract the per-color mean VGG_MEAN
Note: we don’t normalize the data here, as VGG was trained without normalization
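The code below assumes VGG_MEAN holds the per-channel means of the ImageNet training set in RGB order; the values commonly used for VGG are:
VGG_MEAN = [123.68, 116.78, 103.94]  # mean R, G, B over the ImageNet training set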
def training_preprocess(image, label):
    crop_image = tf.random_crop(image, [224, 224, 3])  # (3)
    flip_image = tf.image.random_flip_left_right(crop_image)  # (4)

    means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
    centered_image = flip_image - means  # (5)

    return centered_image, label
Pre-processing (for validation)
3 Take a central 224x224 crop of the scaled image
4 Subtract the per-color mean VGG_MEAN
Note: we don't normalize the data here, as VGG was trained without normalization
def val_preprocess(image, label):
    crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)  # (3)

    means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
    centered_image = crop_image - means  # (4)

    return centered_image, label
2. Dataset creation
The tf.contrib.data.Dataset framework uses queues in the background to feed data into the model. tf.contrib.data.Dataset allows us to load our data within the TensorFlow graph. We initialize the dataset with a list of file names and labels, and then apply the pre-processing functions described above. Behind the scenes, queues will load the file names, apply the pre-processing with multiple threads in parallel, and then batch the data.
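The snippets below read their hyper-parameters from an args object that is never defined in this post (and building train_filenames/train_labels from your data is up to you); a minimal argparse setup, with default values that are illustrative assumptions only, could look like this:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='vgg_16.ckpt', type=str)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--num_epochs1', default=10, type=int)
parser.add_argument('--num_epochs2', default=10, type=int)
parser.add_argument('--learning_rate1', default=1e-3, type=float)
parser.add_argument('--learning_rate2', default=1e-5, type=float)
parser.add_argument('--dropout_keep_prob', default=0.5, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--output_checkpoint', default='finetuned_vgg16.ckpt', type=str)
args = parser.parse_args()

num_classes = 8  # as in this example; in practice, derive it from your labels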
# Training dataset
train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
train_dataset = train_dataset.map(_parse_function,
                                  num_threads=args.num_workers, output_buffer_size=args.batch_size)
train_dataset = train_dataset.map(training_preprocess,
                                  num_threads=args.num_workers, output_buffer_size=args.batch_size)
train_dataset = train_dataset.shuffle(buffer_size=10000)  # don't forget to shuffle
batched_train_dataset = train_dataset.batch(args.batch_size)

# Validation dataset
val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
val_dataset = val_dataset.map(_parse_function,
                              num_threads=args.num_workers, output_buffer_size=args.batch_size)
val_dataset = val_dataset.map(val_preprocess,
                              num_threads=args.num_workers, output_buffer_size=args.batch_size)
batched_val_dataset = val_dataset.batch(args.batch_size)
Now we define an iterator that can operate on either dataset. The iterator can be reinitialized by calling:
- sess.run(train_init_op) for 1 epoch on the training set
- sess.run(val_init_op) for 1 epoch on the validation set
Once this is done, we don't need to feed any value for images and labels, as they are automatically pulled from the iterator queues.
A reinitializable iterator is defined by its structure. We could use the output_types and output_shapes properties of either batched_train_dataset or batched_val_dataset here, because they are compatible.
iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                   batched_train_dataset.output_shapes)
images, labels = iterator.get_next()
train_init_op = iterator.make_initializer(batched_train_dataset)
val_init_op = iterator.make_initializer(batched_val_dataset)
# Indicates whether we are in training or in test mode
is_training = tf.placeholder(tf.bool)
Now that we have set up the data, it's time to set up the model. For this example, we'll use VGG-16 pretrained on ImageNet. We will remove the last fully connected layer (fc8) and replace it with our own, with an output size of num_classes=8. We will first train the last layer for a few epochs, and then train the entire model on our dataset for a few more epochs.
We get the pretrained model, specifying the num_classes argument to create a new fully connected layer that replaces the last one, called "vgg_16/fc8". Each model has a different architecture, so "vgg_16/fc8" will change for another model. Here, logits directly gives us the predicted scores we wanted from the images. We pass a scope to initialize the "vgg_16/fc8" weights with he_initializer.
import tensorflow.contrib.slim as slim

vgg = tf.contrib.slim.nets.vgg
with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=args.weight_decay)):
    logits, _ = vgg.vgg_16(images, num_classes=num_classes, is_training=is_training,
                           dropout_keep_prob=args.dropout_keep_prob)
# Specify where the model checkpoint is (pretrained weights).
model_path = args.model_path
assert os.path.isfile(model_path)
# Restore only the layers up to fc7 (included)
# Calling function `init_fn(sess)` will load all the pretrained weights.
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)
# Initialization operation from scratch for the new "fc8" layers
# `get_variables` will only return the variables whose name starts with the given pattern
fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
fc8_init = tf.variables_initializer(fc8_variables)
Using tf.losses, any loss is added to the tf.GraphKeys.LOSSES collection. We can then easily get the total loss.
tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
loss = tf.losses.get_total_loss()
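Note that tf.losses.get_total_loss() by default also includes the regularization losses created by the weight_decay we set in vgg_arg_scope above, not just the cross-entropy term.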
# First we want to train only the reinitialized last layer fc8 for a few epochs.
# We minimize the loss only with respect to the fc8 variables (weight and bias).
fc8_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate1)
fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)
# Then we want to fine-tune the entire model for a few epochs.
# We minimize the loss with respect to all the variables.
full_optimizer = tf.train.GradientDescentOptimizer(args.learning_rate2)
full_train_op = full_optimizer.minimize(loss)
# Evaluation metrics
prediction = tf.to_int32(tf.argmax(logits, 1))
correct_prediction = tf.equal(prediction, labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()
tf.get_default_graph().finalize()
Now that we have built the graph and finalized it, we define the session. The session is the interface to run the computational graph. We can call our training operations with sess.run(train_op), for instance.
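The training loop below calls a check_accuracy helper that is not defined in this post; a minimal sketch, assuming it simply runs correct_prediction over one full pass of the chosen dataset, could look like this:
def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    # Initialize the iterator on the requested dataset (train or val).
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})
            num_correct += correct_pred.sum()
            num_samples += correct_pred.shape[0]
        except tf.errors.OutOfRangeError:
            break
    # Return the fraction of correctly classified samples.
    return float(num_correct) / num_samples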
with tf.Session() as sess:
    init_fn(sess)  # load the pretrained weights
    sess.run(fc8_init)  # initialize the new fc8 layer

    # Update only the last layer for a few epochs.
    for epoch in range(args.num_epochs1):
        # Run an epoch over the training data.
        print('Update last layer - Starting epoch %d / %d' % (epoch + 1, args.num_epochs1))
        # Here we initialize the iterator with the training set.
        # This means that we can go through an entire epoch until the iterator becomes empty.
        sess.run(train_init_op)
        while True:
            try:
                _ = sess.run(fc8_train_op, {is_training: True})
            except tf.errors.OutOfRangeError:
                break

        # Check accuracy on the train and val sets every epoch.
        train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
        val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
        print('Train accuracy: %f' % train_acc)
        print('Val accuracy: %f\n' % val_acc)

    # Train the entire model for a few more epochs, continuing with the *same* weights.
    for epoch in range(args.num_epochs2):
        print('Training - Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
        sess.run(train_init_op)
        while True:
            try:
                _ = sess.run(full_train_op, {is_training: True})
            except tf.errors.OutOfRangeError:
                break

        # Check accuracy on the train and val sets every epoch.
        train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
        val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
        print('Train accuracy: %f' % train_acc)
        print('Val accuracy: %f\n' % val_acc)

    # Save the fine-tuned weights to a new checkpoint.
    saver.save(sess, args.output_checkpoint)