How to classify with DAG network from checkpoint
I want to use classify() with a DAG network loaded from a checkpoint.
I trained inceptionv3 with transfer learning for many epochs, and training succeeded. I set 'CheckpointPath', so I have a saved network for each epoch. I want to evaluate these networks, so I loaded one and called classify(). However, an error occurred, and the message said "Use trainNetwork". How can I use classify() with a network loaded from a checkpoint?
Naoya on 15 Oct 2018
Thank you very much for providing the details.
Using a checkpoint network that contains BatchNormalization layers for inference is not supported in the latest release (R2018b). I will forward this to our development team as an enhancement request.
We apologize for the inconvenience caused by the current checkpoint functionality.
More Answers (3)
Katja Mogalle on 30 Apr 2021
@Gediminas Simkus had the right idea for the workaround. I can sketch this out a bit more.
To make predictions with the network after training, batch normalization requires a fixed mean and variance to normalize the data. By default, these fixed values are calculated from the entire training data set at the very end of training. A checkpoint, however, is saved before the end of training is reached, so the mean and variance values are not yet set in a checkpoint network.
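To make the role of the fixed statistics concrete, here is a minimal numeric sketch in plain MATLAB (no toolbox needed). It is an illustration only, not the toolbox implementation: the variable names and the toy data are assumptions, chosen to mirror the TrainedMean and TrainedVariance properties of a batch normalization layer.

```matlab
% Illustrative only: per-channel batch normalization on a small matrix
% (rows = observations, columns = channels). All names are hypothetical.
rng(0);
XTrain = 5 + 2*randn(1000,3);   % stand-in for training activations
xNew   = [5 5 5];               % a single observation at inference time

epsilon = 1e-5;                 % small constant for numerical stability
gamma   = ones(1,3);            % learnable scale
beta    = zeros(1,3);           % learnable offset

% Fixed statistics computed once from the whole training set
trainedMean     = mean(XTrain,1);
trainedVariance = var(XTrain,1,1);   % normalize by N, along dimension 1

% Inference uses the fixed statistics, so one observation is enough
yNew = gamma .* (xNew - trainedMean) ./ sqrt(trainedVariance + epsilon) + beta;
disp(yNew)
```

Without trainedMean and trainedVariance the last line cannot be evaluated, which is the situation a checkpoint network is in when these values have not been computed yet.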
Two possible solutions
There are two things you can try in order to use checkpoint networks for inference:
- Since R2021a, running statistics can be enabled for batch normalization layers. The batch normalization statistics are then calculated during training and not at the end of training. The checkpoint networks can be used directly without further modification. To do this, set the 'BatchNormalizationStatistics' name-value pair in trainingOptions to 'moving' when training the network with checkpointing.
- Use trainNetwork with minimal training to convert the checkpoint network into a network with fixed batch normalization mean and variance that can be used for inference. The workaround is based on the process to Resume Training from Checkpoint Network but with some slight tweaks in order to modify the checkpointed network as little as possible.
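The first option could be set up roughly as follows. This is a sketch for R2021a or later; the folder name and the 'MaxEpochs' value are placeholders I made up, and only 'CheckpointPath' and 'BatchNormalizationStatistics' matter here.

```matlab
% Hypothetical folder for the checkpoint files; use your own location.
checkpointDir = fullfile(tempdir,'checkpoints');
if ~isfolder(checkpointDir)
    mkdir(checkpointDir);
end

% Running ('moving') statistics are updated during training, so every
% checkpoint network already carries a usable mean and variance.
options = trainingOptions('sgdm', ...
    'MaxEpochs',30, ...                          % placeholder value
    'CheckpointPath',checkpointDir, ...
    'BatchNormalizationStatistics','moving');
```

Networks saved to checkpointDir during a trainNetwork run with these options should then work with classify() directly.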
Example steps for second workaround using trainNetwork (tested in R2020a and R2020b)
Load the checkpoint network into the workspace (replace this with your own file).
Specify the training options such that training is only run for one iteration, the input data statistics of the input layer are not recomputed, and the learnable parameters are only changed minimally.
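options = trainingOptions('sgdm', ...
    'MaxEpochs',1, ...                     % one epoch with one mini-batch gives a single iteration
    'MiniBatchSize',numel(YTrain), ...     % assumed: the whole training set in one mini-batch
    'ResetInputNormalization',false, ...   % do not recompute the input layer statistics
    'InitialLearnRate',1e-10);             % assumed tiny value so the learnable parameters barely change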
Now “resume” training using the layers of the checkpoint network you loaded with the new training options. If the checkpoint network is a DAG network, then use layerGraph(net) as the argument instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
The returned network can be used for inference.
YPred = classify(net2,XTrain);
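Since the original question was about evaluating the network saved at each epoch, the steps above could be wrapped in a loop over the checkpoint files. The helper below is a hypothetical sketch, not part of the toolbox: the function name, the validation variables, and the option values are my assumptions, and it relies on checkpoint files being named net_checkpoint__*.mat and storing the network in a variable called net.

```matlab
function accuracy = evaluateCheckpoints(checkpointDir,XTrain,YTrain,XVal,YVal)
% Hypothetical helper: converts each checkpoint network with one minimal
% training iteration and returns the validation accuracy per checkpoint.
files    = dir(fullfile(checkpointDir,'net_checkpoint__*.mat'));
accuracy = zeros(numel(files),1);

% Assumed values: one iteration, tiny learning rate, input statistics kept.
options = trainingOptions('sgdm', ...
    'MaxEpochs',1, ...
    'MiniBatchSize',numel(YTrain), ...
    'InitialLearnRate',1e-10, ...
    'ResetInputNormalization',false, ...
    'Verbose',false);

for k = 1:numel(files)
    S = load(fullfile(files(k).folder,files(k).name),'net');
    if isa(S.net,'DAGNetwork')
        layers = layerGraph(S.net);   % DAG networks need the layer graph
    else
        layers = S.net.Layers;        % series networks can use the layer array
    end
    net2        = trainNetwork(XTrain,YTrain,layers,options);
    YPred       = classify(net2,XVal);
    accuracy(k) = mean(YPred == YVal(:));
end
end
```

A call might look like accuracy = evaluateCheckpoints(checkpointDir,XTrain,YTrain,XVal,YVal). For a large network such as inceptionv3, a mini-batch holding the whole training set may not fit in memory; a smaller 'MiniBatchSize' would run more than one iteration, which the tiny learning rate should keep harmless.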
I hope this helps.
Andrea Daou on 8 Oct 2021
I know an answer has already been accepted for this question, but I have a response that might be useful.
If using a network from a checkpoint does not work in your MATLAB version, you can instead write an output function that stops training early, similar to the one in https://fr.mathworks.com/help/deeplearning/ug/customize-output-during-deep-learning-training.html .
For example, instead of being based on Validation Accuracy, it can be based on Validation Loss.
function stop = stopIfValidationLossNotDecreasing(info,N,StartPoint)
stop = false;
% Keep track of the validation loss and the number of successive validations for which
% there has not been a decrease in the loss.
persistent ValLoss
persistent valLag
% Clear the variables when training starts.
if info.State == "start"
    ValLoss = StartPoint; % Value chosen depending on the problem case; check first validation loss.
    valLag = 0;
elseif ~isempty(info.ValidationLoss)
    % Compare the current validation loss to the last validation loss; if
    % the new validation loss is less than the validation loss that
    % precedes it then reset valLag else increment valLag by 1. Now the new
    % ValLoss to compare with is the last one reached.
    if info.ValidationLoss < ValLoss
        valLag = 0;
        ValLoss = info.ValidationLoss;
    else
        valLag = valLag + 1;
        ValLoss = info.ValidationLoss;
    end
    % If the validation lag is at least N, that is, the validation loss
    % has not decreased for at least N validations in a row, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
end
end
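To use such a function, it can be passed to trainingOptions through the 'OutputFcn' option. The sketch below assumes that stopIfValidationLossNotDecreasing is saved on the MATLAB path; the validation data is a random stand-in, and the values of N, StartPoint, and 'ValidationFrequency' are arbitrary choices for illustration.

```matlab
% Stand-in validation data; replace with your own images and labels.
XVal = rand(28,28,1,10);
YVal = categorical(randi(3,10,1));

N          = 5;     % stop after 5 validations in a row without a decrease
StartPoint = Inf;   % assumed start value: the first validation always counts as a decrease

options = trainingOptions('sgdm', ...
    'ValidationData',{XVal,YVal}, ...
    'ValidationFrequency',50, ...
    'OutputFcn',@(info)stopIfValidationLossNotDecreasing(info,N,StartPoint));
```

trainNetwork calls the output function during training, and training stops as soon as the function returns true.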