Fine Tuning
Deep networks need a lot of data to train. What can you do when you don't have much?
- Fine Tuning, where we adapt an existing network to new data
Fine tuning reuses a set of network weights, learned on a previous task, for a new task. This can be seen simply as starting from a much better set of weights than a random initialization.
When we fine tune, we may wish to remove or change some layers. This may be because our new task has a different output size, or because we've changed tasks entirely, for example going from classification to regression.
Keras Example
We will use a new set of data to fine-tune the network. This is just like compiling and fitting the network as per usual. But there are a couple of things to note:
- Generally, we will use a smaller learning rate for fine-tuning; the network is already trained, after all, and we don't want to undo the good work of the old model.
- We often use SGD, as it makes small weight updates.
Preparation
# Assumes tf.keras; load_data() is a stand-in for whatever dataset loader you use
from tensorflow import keras

model = keras.models.load_model("path/to/the/pre-trained-model.h5")
model.summary()
(x_train, y_train), (x_test, y_test) = load_data()
# reshape or normalize the data if necessary
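For image data such as MNIST, that preparation step might look like the sketch below (this assumes 28x28 greyscale images and a channels-last model, which may not match your data):
import numpy as np

# Scale pixel values to [0, 1] and add an explicit channel axis
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)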
Fine-Tuning All Layers
Use the SGD optimizer with a smaller learning rate.
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9),
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=128,
          epochs=10,
          validation_data=(x_test, y_test))
Fine-Tuning Only Some Layers
For very big networks (which we're not using here), we don't really need to fine-tune everything: the early layers are probably already good, and we likely don't have enough data to adjust them in a meaningful way anyway. As such, we may want to restrict which layers we fine-tune, and Keras makes that very easy.
In this example, we will fine-tune only the last 5 layers.
model = keras.models.load_model('../models/vgg_2stage_MNIST_small.h5')

# Freeze the layers except the last 5 layers
for layer in model.layers[:-5]:
    layer.trainable = False

# Check the trainable status of the individual layers
for layer in model.layers:
    print(layer, layer.trainable)

model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9),
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=128,
          epochs=20,
          validation_data=(x_test, y_test))
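If you want to confirm how much of the network is actually frozen, Keras keeps the trainable and non-trainable weights in separate lists. A quick sanity check (just a sketch, not part of the example above) might look like:
import numpy as np

# Count parameters in the trainable vs. frozen parts of the model
trainable = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
frozen = sum(int(np.prod(w.shape)) for w in model.non_trainable_weights)
print(f"Trainable parameters: {trainable:,}  Frozen parameters: {frozen:,}")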
Change From Classification to Regression
Grab the output of a particular layer and pass it to a new Dense layer, then pass that to another Dense layer that acts as the new output. The final Dense layer has a single unit and no activation, giving the single continuous value we need for regression. Finally, we create a model from the original model's input and the new output.
from tensorflow.keras import layers

model = keras.models.load_model("model/path.h5")

# New regression head attached to an intermediate layer of the old model
x = layers.Dense(64, activation="relu")(model.layers[-6].output)
outputs = layers.Dense(1)(x)
new_model = keras.Model(inputs=model.input, outputs=outputs)
new_model.summary()
# train the model
new_model.compile(loss='mean_squared_error',
                  optimizer=keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9))
history = new_model.fit(x_train, y_train,
                        batch_size=128,
                        epochs=20,
                        validation_data=(x_val, y_val))  # x_val/y_val: a held-out validation split for the new task
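As a quick check that the new head really does produce a single continuous value per sample, you might run a prediction (a sketch; variable names follow the examples above):
preds = new_model.predict(x_test[:5])
print(preds.shape)  # expected: (5, 1), one regression output per sample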
Final Thoughts
Fine tuning can be seen as a form of advanced initialization. Rather than initialize our network with random weights, we instead initialize it with weights learned on a (hopefully) related task.
The hope is that several layers of the network (with CNNs, in particular, the convolutional layers that learn filters) will already work well for the new task, and so the network can very quickly become accurate on the new data.
Often when fine tuning, there will be a need to replace some layers, particularly for the network output, to adapt the network to the new task.
This means that:
- these new layers will start from random initializations,
- the rest of the network is initialized with the previously learned weights.
In some cases, in particular where the network is very large, the tasks are quite similar, or data is very limited, a number of layers may be frozen so that their weights are not trained at all.
This can be particularly useful with deep CNNs, where the early convolutional layers learn basic image primitives that typically transfer very well across datasets, so they can be frozen to further reduce training effort.
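Putting these points together, a typical setup might look like the sketch below. The model path, the choice of how many layers to freeze, the cut point for the new head, and the output size are all hypothetical placeholders, not values from the examples above.
from tensorflow import keras
from tensorflow.keras import layers

# Load a pre-trained model; its weights act as an advanced initialization
base = keras.models.load_model("path/to/pretrained.h5")

# Freeze the early layers (hypothetical choice: everything except the last 4)
for layer in base.layers[:-4]:
    layer.trainable = False

# Replace the output with a new, randomly initialized head for the new task
x = layers.Dense(64, activation="relu")(base.layers[-2].output)
outputs = layers.Dense(10, activation="softmax")(x)
new_model = keras.Model(inputs=base.input, outputs=outputs)

# Small learning rate so we don't undo the pre-trained weights
new_model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9),
                  metrics=["accuracy"])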