
Compress Neural Network Using Projection

This example shows how to compress a neural network using projection and principal component analysis.

To compress a deep learning network, you can use projected layers. A projected layer introduces a learnable projector matrix $Q$, replaces multiplications of the form $Wx$, where $W$ is a learnable matrix, with the multiplication $WQQ^\top x$, and stores $Q$ and $W' = WQ$ instead of storing $W$. Projecting $x$ into a lower-dimensional space using $Q$ typically requires less memory to store the learnable parameters and can achieve similarly strong prediction accuracy. A projected deep neural network can also exhibit faster forward passes when run on the CPU or deployed to embedded hardware using library-free C or C++ code generation.
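The arithmetic behind projection can be illustrated in base MATLAB. This is a minimal sketch, assuming that $Q$ holds the top principal directions of the layer inputs; the sizes, the random data, and the variable names are hypothetical, and the sketch is not how compressNetworkUsingProjection is implemented internally. It compares $Wx$ with $(WQ)(Q^\top x)$ and counts the stored values in each case.

```matlab
% Sketch: replace W*x with (W*Q)*(Q'*x), where Q has orthonormal columns.
% All sizes and data are hypothetical.
rng(0);                          % reproducible random data
m = 100; n = 12; k = 4;          % output dim, input dim, projected dim
W = randn(m,n);                  % stand-in for a learnable weight matrix
X = randn(n,k)*randn(k,500);     % 500 inputs lying in a k-dimensional subspace
X = X + 0.01*randn(n,500);       % plus a little noise

% Principal directions of the inputs: eigenvectors of their covariance.
[V,D] = eig(cov(X'));
[~,order] = sort(diag(D),'descend');
Q = V(:,order(1:k));             % n-by-k projector with orthonormal columns

Wp = W*Q;                        % stored m-by-k matrix W' = W*Q
Y = W*X;                         % original layer output
Yp = Wp*(Q'*X);                  % projected layer output

relErr = norm(Y - Yp,'fro')/norm(Y,'fro');
numOriginal = numel(W);                  % m*n values
numProjected = numel(Wp) + numel(Q);     % m*k + n*k values
fprintf("Relative error: %.4f\n",relErr)
fprintf("Stored values: %d -> %d\n",numOriginal,numProjected)
```

Storing $W'$ and $Q$ costs $k(m+n)$ values instead of $mn$, so projection saves memory whenever $k < mn/(m+n)$.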

The compressNetworkUsingProjection function compresses a network by projecting layers into smaller parameter subspaces. For optimal initialization of the projected network, the function projects the learnable parameters of projectable layers into a subspace that maintains the highest variance in neuron activations. After you compress a neural network using projection, you can then fine-tune the network to increase the accuracy.

This chart shows the effect of projection and fine-tuning on a trained network. In this case, the projected network has significantly fewer learnable parameters at the cost of classification accuracy. The fine-tuned projected network yields similar classification accuracy to the original network.

[Figure: Effect of projection and fine-tuning on the accuracy and number of learnables of a trained network]

Load Pretrained Network

Load the pretrained network in dlnetJapaneseVowels.

load dlnetJapaneseVowels

View the network layers. The network is an LSTM network with a single LSTM layer with 100 hidden units.

net.Layers
ans = 
  4×1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax
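The layer sizes above determine the number of learnables in the network. The following worked count is a sketch that assumes the standard LSTM parameterization (four gates, each with input weights, recurrent weights, and a bias); the total agrees with the memory size that the parameterMemory supporting function reports for the original network later in the example.

```matlab
% Worked count of learnables for the network above.
% Assumes the standard LSTM parameterization with four gates.
numFeatures = 12;
numHiddenUnits = 100;
numClasses = 9;

lstmLearnables = 4*numHiddenUnits*numFeatures ...       % input weights: 4800
    + 4*numHiddenUnits*numHiddenUnits ...               % recurrent weights: 40000
    + 4*numHiddenUnits;                                 % biases: 400
fcLearnables = numClasses*numHiddenUnits + numClasses;  % weights and biases: 909

totalLearnables = lstmLearnables + fcLearnables         % 46109
bytesSingle = 4*totalLearnables                         % 184436 bytes in single precision
```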

View the class names of the network.

classNames
classNames = 9×1 string
    "1"
    "2"
    "3"
    "4"
    "5"
    "6"
    "7"
    "8"
    "9"

This example trains several networks. For comparison, create a copy of the original network.

netOriginal = net;

Load Training Data

Load the Japanese Vowels data set described in [1] and [2]. XTrain is a cell array containing 270 sequences of varying length with 12 features corresponding to LPC cepstrum coefficients. TTrain is a categorical vector of labels 1, 2, ..., 9. The entries in XTrain are matrices with 12 rows (one row for each feature) and a varying number of columns (one column for each time step).

[XTrain,TTrain] = japaneseVowelsTrainData;

Analyze Neuron Activations for Compression Using Projection

The compressNetworkUsingProjection function uses principal component analysis (PCA) to identify the subspace of learnable parameters that results in the highest variance in neuron activations. To find this subspace, the analysis evaluates the network activations on a set of training data. This analysis requires only the predictors of the training data to compute the network activations. It does not require the training targets.
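The idea of keeping the directions with the highest activation variance can be sketched with an eigendecomposition of the activation covariance. This is an illustration under stated assumptions, not the neuronPCA algorithm itself: the activations are random hypothetical data, and the explained variance goal of 0.95 is an arbitrary choice. The sketch finds the smallest number of principal directions whose eigenvalues account for at least the goal fraction of the total variance.

```matlab
% Sketch: fraction of activation variance kept by the top-k principal directions.
rng(0);
% Hypothetical activations: 100 neurons, 2000 observations, decreasing spread.
A = randn(100,2000).*linspace(3,0.1,100)';
lambda = sort(eig(cov(A')),'descend');   % eigenvalues of the activation covariance
explained = cumsum(lambda)/sum(lambda);  % cumulative explained variance

goal = 0.95;                             % hypothetical explained variance goal
k = find(explained >= goal,1);           % smallest subspace that meets the goal
fprintf("Keep %d of %d directions (%.1f%% of the variance)\n", ...
    k,numel(lambda),100*explained(k))
```

A higher goal keeps more directions and therefore more learnables, which is the trade-off that the ExplainedVarianceGoal option controls.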

The PCA step can be computationally intensive. If you expect to compress the same network multiple times (for example, when exploring different levels of compression), then perform the PCA step first and reuse the resulting neuronPCA object.

Create a mini-batch queue containing the training data. To create a mini-batch queue from in-memory data, convert the sequences to an array datastore.

adsXTrain = arrayDatastore(XTrain,OutputType="same");

Create the minibatchqueue object.

  • Specify a mini-batch size of 16.

  • Preprocess the mini-batches using the preprocessMiniBatchPredictors function, listed in the Mini-Batch Predictors Preprocessing Function section of the example.

  • Specify that the output data has format "CTB" (channel, time, batch).

Note: Do not pad sequence data when doing the PCA step for projection as this can negatively impact the analysis. Instead, truncate mini-batches of data to have the same length or use mini-batches of size 1.

miniBatchSize = 16;

mbqTrain = minibatchqueue(adsXTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatchPredictors, ...
    MiniBatchFormat="CTB");

Create the neuronPCA object. To view information about the steps of the neuron PCA algorithm, set the VerbosityLevel option to "steps".

npca = neuronPCA(netOriginal,mbqTrain,VerbosityLevel="steps");
Computing layer activations and covariance matrices...
Computing eigenvalues and eigenvectors...
neuronPCA analyzed 1 layers: "lstm"

View the properties of the neuronPCA object.

npca
npca = 
  neuronPCA with properties:

            LayerNames: "lstm"
      InputEigenvalues: {[12×1 double]}
     InputEigenvectors: {[12×12 double]}
     OutputEigenvalues: {[100×1 double]}
    OutputEigenvectors: {[100×100 double]}

Project Network

Compress the network using the neuron PCA object.

netProjected = compressNetworkUsingProjection(netOriginal,npca);
Compressed network has 82.4% fewer learnable parameters.
Projected layers explain on average 96.6% of layer activation variance.

Test Projected Network

Load the Japanese Vowels test data set.

[XTest,TTest] = japaneseVowelsTestData;

Create a mini-batch queue using the same steps as for the training data.

adsTest = arrayDatastore(XTest,OutputType="same");

mbqTest = minibatchqueue(adsTest, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatchPredictors, ...
    MiniBatchFormat="CTB");

For comparison, calculate the classification accuracy of the original network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(netOriginal,mbqTest,classNames);
accOriginal = mean(YTest == TTest)
accOriginal = 0.9135

Calculate the classification accuracy of the projected network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(netProjected,mbqTest,classNames);
accProjected = mean(YTest == TTest)
accProjected = 0.4784

Compare the accuracy and the number of learnables of each network in a bar chart. To calculate the number of learnables of each network, use the numLearnables function, listed in the Number of Learnables Function section of the example.

figure
tiledlayout("flow")

nexttile
bar([accOriginal accProjected])
xticklabels(["Original" "Projected"])
title("Accuracy")
ylabel("Accuracy")

nexttile
bar([numLearnables(netOriginal) numLearnables(netProjected)])
xticklabels(["Original" "Projected"])
ylabel("Number of Learnables")
title("Number of Learnables")

The projected network yields worse classification accuracy and has significantly fewer learnable parameters.

Compress for Memory Requirement

If you want to compress a network so that it meets specific hardware memory requirements, then you can manually calculate the learnable reduction value such that the compressed network is of the desired size.

Specify a target memory requirement of 64 kilobytes (64×1024 bytes).

targetMemorySize = 64*1024
targetMemorySize = 65536

Calculate the memory size of the original network using the parameterMemory function, listed in the Parameter Memory Function section of the example.

memorySizeOriginal = parameterMemory(netOriginal)
memorySizeOriginal = 184436

Calculate the fraction by which to reduce the learnables so that the resulting network meets the memory requirement.

reductionGoal = 1 - (targetMemorySize/memorySizeOriginal);
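With the numbers above, the reduction goal can be checked by hand. This worked arithmetic reuses the target of 65536 bytes and the original size of 184436 bytes, and assumes 4 bytes per single-precision learnable, as the parameterMemory supporting function does.

```matlab
% Worked arithmetic for the memory-based reduction goal.
targetMemorySize = 64*1024;          % 65536 bytes
memorySizeOriginal = 184436;         % bytes, as reported for netOriginal
reductionGoal = 1 - targetMemorySize/memorySizeOriginal   % approximately 0.6447

% Largest number of single-precision learnables that fits the target.
maxLearnables = floor(targetMemorySize/4)                  % 16384

% The projected network reported 65364 bytes, that is, 16341 learnables.
projectedLearnables = 65364/4
```

The projected network lands just under the target because the function rounds the projection to whole subspace dimensions.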

Project the network using the compressNetworkUsingProjection function and set the LearnablesReductionGoal option to the calculated reduction factor.

netProjected = compressNetworkUsingProjection(netOriginal,npca, ...
    LearnablesReductionGoal=reductionGoal);
Compressed network has 64.6% fewer learnable parameters.
Projected layers explain on average 99.7% of layer activation variance.

Calculate the memory size of the projected network using the parameterMemory function, listed in the Parameter Memory Function section of the example.

memorySizeProjected = parameterMemory(netProjected)
memorySizeProjected = 65364

Calculate the classification accuracy of the projected network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(netProjected,mbqTest,classNames);
accProjected = mean(YTest == TTest)
accProjected = 0.8649

Compare the accuracy and the memory size of each network in a bar chart.

figure
tiledlayout("flow")

nexttile
bar([accOriginal accProjected])
xticklabels(["Original" "Projected"])
ylabel("Accuracy")
title("Accuracy")

nexttile
bar([memorySizeOriginal memorySizeProjected])
xticklabels(["Original" "Projected"])
yline(targetMemorySize,"r--","Memory Requirement")
ylabel("Memory (bytes)")
title("Memory Size")

The projected network yields slightly lower classification accuracy than the original network and has a memory size that meets the memory requirement.

Explore Compression Levels

There is a trade-off between the amount of compression and the network accuracy. In particular, reducing the number of learnable parameters typically reduces the network accuracy.

The explained variance measures how much of the variance in the original layer activations the projected subspace retains. To explore different amounts of compression, you can iterate over different values of the ExplainedVarianceGoal option of the compressNetworkUsingProjection function and compare the results.

Loop over different values of the explained variance goal. Iterate over 20 values between 0.999 and 0, chosen so that 1 minus the goal is logarithmically spaced between 0.001 and 1.

For each value:

  • Compress the network using projection with the specified explained variance goal using the compressNetworkUsingProjection function. Suppress verbose output by setting the VerbosityLevel option to "off".

  • Record the actual explained variance and learnables reduction of the projected network.

  • Calculate the classification accuracy of the projected network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

numValues = 20;
explainedVarGoal = 1 - logspace(-3,0,numValues);

% Preallocate the results.
explainedVariance = zeros(1,numValues);
learnablesReduction = zeros(1,numValues);
accuracy = zeros(1,numValues);

for i = 1:numValues
    varianceGoal = explainedVarGoal(i);

    [netProjected,info] = compressNetworkUsingProjection(netOriginal,npca, ...
        ExplainedVarianceGoal=varianceGoal, ...
        VerbosityLevel="off");

    explainedVariance(i) = info.ExplainedVariance;
    learnablesReduction(i) = info.LearnablesReduction;

    YTest = modelPredictions(netProjected,mbqTest,classNames);
    accuracy(i) = mean(YTest==TTest);
end
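The goal values used in the loop can be inspected on their own. This short check only recomputes the expression from the loop: the goals run from 0.999 down to 0, and the gaps 1 - goal grow by a constant factor, so more of the goals lie close to 1, where small changes in the goal matter most.

```matlab
% Inspect the explained variance goals used in the loop.
numValues = 20;
explainedVarGoal = 1 - logspace(-3,0,numValues);
firstGoal = explainedVarGoal(1)          % 0.999
lastGoal = explainedVarGoal(end)         % 0

% Successive gaps 1 - goal differ by the constant factor 10^(3/19).
gapRatio = (1 - explainedVarGoal(2))/(1 - explainedVarGoal(1))
```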

Visualize the effect of the different settings of the explained variance goal in a plot.

figure
tiledlayout("flow")

nexttile
plot(learnablesReduction,accuracy,'+-')
ylabel("Accuracy")
title("Effect of Explained Variance Goal")

nexttile
plot(learnablesReduction,explainedVariance,'+-')
ylim([0 inf])
ylabel("Explained Variance")
xlabel("Learnable Reduction")

The graphs show that an increase in learnable reduction has a corresponding decrease in the explained variance and accuracy. At a learnable reduction of around 85%, the explained variance decreases only slightly, although the accuracy of the projected network can still drop noticeably before fine-tuning.

Compress the network using projection with a learnable reduction goal of 85% using the compressNetworkUsingProjection function. Suppress verbose output by setting the VerbosityLevel option to "off".

netProjected = compressNetworkUsingProjection(netOriginal,npca, ...
    LearnablesReductionGoal=0.85, ...
    VerbosityLevel="off");

Calculate the classification accuracy of the projected network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(netProjected,mbqTest,classNames);
accProjected = mean(YTest == TTest)
accProjected = 0.4270

Compare the accuracy and the number of learnables of each network in a bar chart. To calculate the number of learnables of each network, use the numLearnables function, listed in the Number of Learnables Function section of the example.

figure
tiledlayout("flow")

nexttile
bar([accOriginal accProjected])
xticklabels(["Original" "Projected"])
ylabel("Accuracy")
title("Accuracy")

nexttile
bar([numLearnables(netOriginal) numLearnables(netProjected)])
xticklabels(["Original" "Projected"])
ylabel("Number of Learnables")
title("Number of Learnables")

The projected network yields worse classification accuracy and has significantly fewer learnable parameters. You can improve the network accuracy by fine-tuning the network.

Fine-Tune Compressed Network

Compressing a network using projection typically reduces the network accuracy. You can improve the accuracy by retraining the network, also known as fine-tuning the network.

Specify the options for fine-tuning. Train for 30 epochs with a learning rate of 0.0005.

numEpochs = 30;
learnRate = 0.0005;

The projection steps of the workflow required only the training predictors. Retraining the network requires both the predictors and the labels. Create a mini-batch queue that outputs both the predictors and the labels.

Create a combined datastore that outputs the training predictors and the labels by combining the array datastore of predictors with an array datastore of the labels.

adsTTrain = arrayDatastore(TTrain,IterationDimension=1);
cdsTrain = combine(adsXTrain,adsTTrain);

Create a mini-batch queue that outputs mini-batches of predictors and labels:

  • Specify that the mini-batch queue has two outputs.

  • Specify the same mini-batch size of 16.

  • Preprocess the mini-batches using the preprocessMiniBatch function, listed in the Mini-Batch Preprocessing Function section of the example.

  • Specify that the predictors have format "CTB" (channel, time, batch) and that the preprocessed labels have format "CB" (channel, batch).

miniBatchSize = 16;

mbqTrain = minibatchqueue(cdsTrain,2, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["CTB" "CB"]);

Initialize the parameters for Adam optimization.

averageGrad = [];
averageSqGrad = [];

For comparison, create a copy of the network object to train.

netFineTuned = netProjected;

Calculate the total number of iterations for the training progress monitor.

numObservationsTrain = size(XTrain,1);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
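For the 270 training sequences in this example, the iteration counts work out as follows. This worked arithmetic assumes that the mini-batch queue returns the final partial mini-batch of each epoch rather than discarding it.

```matlab
% Worked iteration counts for the training progress monitor.
numObservationsTrain = 270;
miniBatchSize = 16;
numEpochs = 30;

% 16 full mini-batches plus one partial mini-batch of 14 sequences.
numIterationsPerEpoch = ceil(numObservationsTrain/miniBatchSize)   % 17
numIterations = numEpochs*numIterationsPerEpoch                     % 510
lastBatchSize = numObservationsTrain - (numIterationsPerEpoch - 1)*miniBatchSize   % 14
```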

Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info=["Epoch","LearnRate"], ...
    XLabel="Iteration");

Fine-tune the network using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:

  • Evaluate the model loss and gradients using the modelLoss function, listed in the Model Loss Function section of the example.

  • Update the network parameters using the adamupdate function.

  • Update the loss, learn rate, and epoch values in the training progress monitor.

  • Stop if the Stop property is true. The Stop property value of the TrainingProgressMonitor object changes to true when you click the Stop button.

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop

    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbqTrain);

    % Loop over mini-batches.
    while hasdata(mbqTrain) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbqTrain);

        % Evaluate the model gradients and loss using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,netFineTuned,X,T);

        % Update the network parameters using the ADAM optimizer.
        [netFineTuned,averageGrad,averageSqGrad] = adamupdate(netFineTuned,gradients, ...
            averageGrad,averageSqGrad,iteration,learnRate);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
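The adamupdate call in the loop applies the Adam optimization rule to every learnable parameter. The following sketch shows the standard Adam update on a toy scalar problem, minimizing $f(p) = (p-3)^2$. The decay factors and epsilon are assumed common Adam defaults, and the toy objective is hypothetical. The sketch also shows why the training loop passes the global iteration count rather than the iteration within the epoch: the bias-correction terms depend on it.

```matlab
% Sketch of the standard Adam update: minimize f(p) = (p - 3)^2.
% Decay factors and epsilon are assumed values (common Adam defaults).
learnRate = 0.1;
beta1 = 0.9;          % gradient decay factor
beta2 = 0.999;        % squared gradient decay factor
epsilon = 1e-8;

p = 0;                % parameter to optimize
avgGrad = 0;          % running average of the gradient
avgSqGrad = 0;        % running average of the squared gradient

for iteration = 1:500
    g = 2*(p - 3);                                  % gradient of f at p
    avgGrad = beta1*avgGrad + (1 - beta1)*g;
    avgSqGrad = beta2*avgSqGrad + (1 - beta2)*g.^2;
    mHat = avgGrad/(1 - beta1^iteration);           % bias-corrected averages
    vHat = avgSqGrad/(1 - beta2^iteration);
    p = p - learnRate*mHat./(sqrt(vHat) + epsilon);
end
disp(p)   % close to the minimizer 3
```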

Test Fine-Tuned Network

Calculate the classification accuracy of the fine-tuned network using the test data and the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(netFineTuned,mbqTest,classNames);
accFineTuned = mean(YTest == TTest)
accFineTuned = 0.9081

Compare the accuracy and the number of learnables of each network in a bar chart. To calculate the number of learnables of each network, use the numLearnables function, listed in the Number of Learnables Function section of the example.

figure
tiledlayout("flow")
nexttile
bar([accOriginal accProjected accFineTuned])
xticklabels(["Original" "Projected" "Fine-Tuned Projected"])
title("Accuracy")
ylabel("Accuracy")

nexttile
bar([numLearnables(netOriginal) numLearnables(netProjected) numLearnables(netFineTuned)])
xticklabels(["Original" "Projected" "Fine-Tuned Projected"])
ylabel("Number of Learnables")
title("Number of Learnables")

The projected network has significantly fewer learnable parameters at the cost of classification accuracy. The fine-tuned projected network yields similar classification accuracy to the original network.

Supporting Functions

Mini-Batch Predictors Preprocessing Function

The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the sequence data from the input cell array and truncating the sequences along the second (time) dimension so that they all have the length of the shortest sequence in the mini-batch.

Note: Do not pad sequence data when doing the PCA step for projection as this can negatively impact the analysis. Instead, truncate mini-batches of data to have the same length or use mini-batches of size 1.

function X = preprocessMiniBatchPredictors(dataX)

X = padsequences(dataX,2,Length="shortest");

end
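The truncation that padsequences performs with Length="shortest" can be reproduced in base MATLAB. This sketch uses a hypothetical mini-batch of three random sequences, each with 12 features and a different number of time steps. It assumes the output layout is features-by-time-by-batch, matching the "CTB" format used for the mini-batch queues.

```matlab
% Sketch: truncate a mini-batch of sequences to the shortest length.
% Each sequence is numFeatures-by-numTimeSteps; the data is hypothetical.
rng(0);
dataX = {rand(12,20); rand(12,15); rand(12,26)};
minLength = min(cellfun(@(x) size(x,2),dataX));          % shortest sequence length
truncated = cellfun(@(x) x(:,1:minLength),dataX,'UniformOutput',false);
X = cat(3,truncated{:});                                 % C-by-T-by-B array
size(X)
```

Truncation discards the trailing time steps of the longer sequences, but every time step that remains is real data, which is why the example prefers it to padding for the PCA step.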

Preprocess Mini-Batch Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:

  1. Preprocess the sequences using the preprocessMiniBatchPredictors function.

  2. Extract the label data and concatenate into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function [X,T] = preprocessMiniBatch(dataX,dataT)

X = preprocessMiniBatchPredictors(dataX);

dataT = cat(2,dataT{:});
T = onehotencode(dataT,1);

end
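The one-hot encoding step can be illustrated without the toolbox function. This sketch builds a numClasses-by-numObservations array with a single 1 in each column, the layout that the description above attributes to onehotencode(dataT,1). The label values are a hypothetical mini-batch.

```matlab
% Sketch: one-hot encode categorical labels into the first dimension.
classNames = string(1:9);
labels = categorical(["3" "1" "9" "3"],classNames);   % hypothetical labels, 1-by-4

idx = double(labels);                  % category indices in the range 1..9
numObs = numel(idx);
T = zeros(numel(classNames),numObs);
T(sub2ind(size(T),idx,1:numObs)) = 1;  % a 1 in the row of each observation's class
disp(T)
```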

Number of Learnables Function

The numLearnables function returns the total number of learnables in a network.

function N = numLearnables(net)

N = 0;
for i = 1:size(net.Learnables,1)
    N = N + numel(net.Learnables.Value{i});
end

end

Parameter Memory Function

The parameterMemory function returns the size in bytes of the learnable parameters of a network, where the learnable parameters are in single precision (4 bytes per learnable).

function numBytes = parameterMemory(net)

numBytes = 4*numLearnables(net);

end

Model Loss Function

The modelLoss function takes a dlnetwork object net and a mini-batch of input data X with corresponding targets T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net.

function [loss,gradients] = modelLoss(net,X,T)

Y = forward(net,X);

loss = crossentropy(Y,T);

gradients = dlgradient(loss,net.Learnables);

end
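For one-hot targets, the cross-entropy loss reduces to the average negative log-probability of the true class. This sketch computes it on hypothetical softmax scores for two observations. It assumes that crossentropy, with its default settings on "CB" data, normalizes by the number of observations in the mini-batch.

```matlab
% Sketch: cross-entropy loss for one-hot targets, averaged over the batch.
Y = [0.7 0.2; 0.2 0.5; 0.1 0.3];   % hypothetical softmax scores: 3 classes, 2 observations
T = [1 0; 0 1; 0 0];               % one-hot targets
numObs = size(Y,2);
loss = -sum(T.*log(Y),'all')/numObs   % (-log(0.7) - log(0.5))/2
```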

Model Predictions Function

The modelPredictions function takes a dlnetwork object net, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.

function Y = modelPredictions(net,mbq,classNames)

Y = [];
reset(mbq)

while hasdata(mbq)
    X = next(mbq);

    scores = predict(net,X);

    labels = onehotdecode(scores,classNames,1)';

    Y = [Y; labels];
end

end
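The decoding step in modelPredictions picks, for each observation, the class with the highest score. This sketch reproduces that step with max on hypothetical random scores. It returns a column of categorical labels, matching the transpose in modelPredictions.

```matlab
% Sketch: decode scores into labels by taking the top-scoring class per column.
classNames = string(1:9);
rng(0);
scores = rand(9,5);                              % hypothetical scores: 9 classes, 5 observations
scores = scores./sum(scores,1);                  % normalize columns like softmax output
[~,idx] = max(scores,[],1);                      % row of the top score for each observation
labels = categorical(classNames(idx),classNames)'   % 5-by-1 categorical column
```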

Bibliography

  1. M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pages 1103–1111.

  2. UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels