@Gediminas Simkus had the right idea for the workaround. I can sketch this out a bit more.
To make predictions with the network after training, batch normalization requires a fixed mean and variance to normalize the data. By default, these fixed values are calculated from the entire training data set at the very end of training. When you use a checkpoint network, the end of training was never reached, so the mean and variance values are not set.
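To make that concrete, here is a minimal sketch of what a batch normalization layer computes at prediction time. It uses plain arrays only, and all values and variable names (trainedMean, trainedVariance, gamma, beta, epsilon) are made up for illustration rather than taken from any real network.

```matlab
% Statistics that would normally be fixed at the end of training
% (illustrative values, one entry per channel)
trainedMean = [0.5 2.0];
trainedVariance = [4.0 0.25];
epsilon = 1e-5;   % small constant for numerical stability
gamma = [1 1];    % learnable scale (illustrative)
beta = [0 0];     % learnable offset (illustrative)

% Data to predict on: rows are observations, columns are channels
X = [2.5 2.5; -1.5 1.5];

% Inference normalizes with the stored statistics, not with the
% mean and variance of X itself
Z = gamma .* (X - trainedMean) ./ sqrt(trainedVariance + epsilon) + beta;
disp(Z)
```

If trainedMean and trainedVariance were never set, as in a checkpoint network saved with the default settings, this step has nothing to normalize with.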
Two possible solutions
There are two things you can try in order to use checkpoint networks for inference:
- Since R2021a, running statistics can be enabled for batch normalization layers. To do this, set the 'BatchNormalizationStatistics' name-value pair in trainingOptions to 'moving' when training the network with checkpointing. The batch normalization statistics are then updated during training instead of being calculated at the end, so the checkpoint networks can be used directly without further modification.
- Use trainNetwork with minimal training to convert the checkpoint network into a network with fixed batch normalization mean and variance that can be used for inference. The workaround is based on the process to Resume Training from Checkpoint Network but with some slight tweaks in order to modify the checkpointed network as little as possible.
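For the first option, the training options could look roughly like this. It is only a sketch: the checkpoint folder is a hypothetical location, and you would add your usual options (epochs, learning rate, and so on) alongside these.

```matlab
% Hypothetical folder for the checkpoint files; replace with your own path
checkpointFolder = fullfile(tempdir,'checkpoints');
if ~exist(checkpointFolder,'dir')
    mkdir(checkpointFolder);
end

options = trainingOptions('sgdm', ...
    'BatchNormalizationStatistics','moving', ... % requires R2021a or later
    'CheckpointPath',checkpointFolder);
```

Checkpoint networks saved during a training run with these options should already carry usable batch normalization statistics.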
Example steps for second workaround using trainNetwork (tested in R2020a and R2020b)
Load the checkpoint network into the workspace (replace this with your own file).
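As a sketch, loading could look like the following. The file name here is a placeholder, not a real checkpoint file, so substitute the name of one of your own checkpoint files.

```matlab
% Hypothetical file name; replace with one of your own checkpoint files
checkpointFile = 'net_checkpoint.mat';

% trainNetwork saves the checkpoint network in a variable named net
load(checkpointFile,'net');
```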
Specify the training options such that training is only run for one iteration, the input data statistics of the input layer are not recomputed, and the learnable parameters are only changed minimally.
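% Illustrative values: one epoch with a single mini-batch gives one iteration
% (assumes YTrain holds one label per observation; reduce the mini-batch size
% if memory is limited), and the tiny learning rate keeps the weights almost unchanged.
options = trainingOptions('sgdm', ...
    'MaxEpochs',1, 'MiniBatchSize',numel(YTrain), ...
    'InitialLearnRate',1e-10, 'ResetInputNormalization',false);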
Now “resume” training using the layers of the checkpoint network you loaded with the new training options. If the checkpoint network is a DAG network, then use layerGraph(net) as the argument instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
The returned network can be used for inference.
YPred = classify(net2,XTrain);
I hope this helps.