
taylorPrunableNetwork

Network that can be pruned by using a first-order Taylor approximation

Description

A TaylorPrunableNetwork object supports pruning filters in convolution layers by using a first-order Taylor approximation. To prune filters in a dlnetwork object, first convert it to a TaylorPrunableNetwork object and then use the associated object functions.

Pruning a deep neural network requires the Deep Learning Toolbox™ Model Quantization Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Quantization Library.

Creation

Description


prunableNet = taylorPrunableNetwork(net) converts a dlnetwork object net to a TaylorPrunableNetwork object. The TaylorPrunableNetwork object is a different representation of the same network that is suitable for pruning by using the Taylor pruning algorithm. If the input network cannot be pruned, the function throws an error.

prunableNet = taylorPrunableNetwork(layers) converts the network layers specified in layers to a TaylorPrunableNetwork object that is suitable for pruning by using the Taylor pruning algorithm. The input layers must be a LayerGraph object or a Layer array that can be converted to a dlnetwork object.
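The following is a minimal sketch of the Layer array syntax. The small network below (the variable names layers and prunableNet and all layer sizes) is hypothetical and chosen only for illustration; it contains no output layer, so it can be converted to a dlnetwork object. The number of prunable filters it reports is not stated here, because it depends on how the Taylor pruning algorithm treats this particular network.

```matlab
% Hypothetical small network defined as a Layer array. It has no output
% layer, so it can be converted to a dlnetwork object.
layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];

% Convert the layers directly to a TaylorPrunableNetwork object.
prunableNet = taylorPrunableNetwork(layers);

% Learnables: conv Weights/Bias, batchnorm Offset/Scale, fc Weights/Bias.
prunableNet.Learnables

% Number of filters the Taylor algorithm can consider for pruning.
prunableNet.NumPrunables
```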

Input Arguments


net – Neural network, specified as a dlnetwork object.

layers – Network layers, specified as a LayerGraph object or a Layer array.

Properties


Learnables – Network learnable parameters, specified as a table with three columns:

  • Layer – Layer name, specified as a string scalar.

  • Parameter – Parameter name, specified as a string scalar.

  • Value – Value of parameter, specified as a dlarray object.

The network learnable parameters contain the features learned by the network, for example, the weights of convolution and fully connected layers.

Data Types: table

State – Network state, specified as a table with three columns:

  • Layer – Layer name, specified as a string scalar.

  • Parameter – State parameter name, specified as a string scalar.

  • Value – Value of state parameter, specified as a dlarray object.

Layer states contain information that a layer calculates during its operation and retains for use in subsequent forward passes, for example, the cell state and hidden state of LSTM layers, or the running statistics of batch normalization layers.

For recurrent layers, such as LSTM layers, with the HasStateInputs property set to 1 (true), the state table does not contain entries for the states of that layer.

During training or inference, you can update the network state using the output of the forward and predict functions.
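As a minimal sketch of updating the state, the code below runs forward on a TaylorPrunableNetwork object and writes the returned state back to the State property. The network, the random input, and the variable names are hypothetical. The sketch assumes that the State property of a TaylorPrunableNetwork object can be assigned in the same way as the State property of a dlnetwork object.

```matlab
% Hypothetical small network, used only to illustrate the State update.
layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];
prunableNet = taylorPrunableNetwork(layers);

% Random mini-batch of eight 28-by-28 grayscale images.
X = dlarray(rand(28,28,1,8,"single"),"SSCB");

% forward returns the network output and the updated state.
[Y,state] = forward(prunableNet,X);

% Store the updated batch normalization statistics in the network.
prunableNet.State = state;
```

With one batch normalization layer, the state table holds two rows (the running mean and the running variance).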

Data Types: table

InputNames – Network input layer names, specified as a cell array of character vectors. This property is read-only.

Data Types: cell

OutputNames – Names of layers that return network outputs, specified as a cell array of character vectors or a string array.

If you do not specify the output names, then the software sets the OutputNames property to the layers with disconnected outputs. If a layer has multiple outputs, then the disconnected outputs are specified as 'layerName/outputName'.

The predict and forward functions, by default, return the data output by the layers given by the OutputNames property.

Data Types: cell | string

NumPrunables – Number of convolution layer filters in the network that are suitable for pruning by using a first-order Taylor approximation, specified as a nonnegative integer.

Object Functions

forward – Compute deep learning network output for training
predict – Compute deep learning network output for inference
updatePrunables – Remove filters from prunable layers based on importance scores
updateScore – Compute and accumulate Taylor-based importance scores for pruning
dlnetwork – Deep learning network for custom training loops

Examples


This example shows how to prune a dlnetwork object by using a custom pruning loop.

Load dlnetwork Object

Load a trained dlnetwork object and the corresponding classes.

s = load("digitsCustom.mat");
dlnet_1 = s.dlnet;
classes = s.classes;

Inspect the layers of the dlnetwork object. The network has three convolution layers at locations 2, 5, and 8 of the Layer array.

layers_1 = dlnet_1.Layers
layers_1 = 
  12x1 Layer array with layers:

     1   'input'     Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv1'     2-D Convolution       20 5x5x1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'bn1'       Batch Normalization   Batch normalization with 20 channels
     4   'relu1'     ReLU                  ReLU
     5   'conv2'     2-D Convolution       20 3x3x20 convolutions with stride [1  1] and padding [1  1  1  1]
     6   'bn2'       Batch Normalization   Batch normalization with 20 channels
     7   'relu2'     ReLU                  ReLU
     8   'conv3'     2-D Convolution       20 3x3x20 convolutions with stride [1  1] and padding [1  1  1  1]
     9   'bn3'       Batch Normalization   Batch normalization with 20 channels
    10   'relu3'     ReLU                  ReLU
    11   'fc'        Fully Connected       10 fully connected layer
    12   'softmax'   Softmax               softmax

Load Data for Prediction

Load the digits data for prediction.

dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");

imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

Partition the data into pruning and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.

[imdsPrune,imdsValidation] = splitEachLabel(imds,0.9,"randomize");

The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the images, use augmented image datastores.

inputSize = [28 28 1];
augimdsPrune = augmentedImageDatastore(inputSize(1:2),imdsPrune);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Prune dlnetwork Object

Convert the dlnetwork object to a representation that is suitable for pruning by using the taylorPrunableNetwork function. This function returns a TaylorPrunableNetwork object that has the NumPrunables property set to 48. This indicates that 48 filters in the original model are suitable for pruning by using the Taylor pruning algorithm.

prunableNet_1 = taylorPrunableNetwork(dlnet_1)
prunableNet_1 = 
  TaylorPrunableNetwork with properties:

      Learnables: [14x3 table]
           State: [6x3 table]
      InputNames: {'input'}
     OutputNames: {'softmax'}
    NumPrunables: 48

Create a minibatchqueue object that processes and manages mini-batches of images during pruning. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.

  • Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not format the class labels.

  • Run on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

miniBatchSize = 128;

mbq = minibatchqueue(augimdsPrune, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);
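As a quick sanity check, you can pull one mini-batch from the queue and inspect it before scoring. In this minimal sketch, X should be a formatted 28-by-28-by-1-by-128 dlarray with format "SSCB", and T a 10-by-128 array of one-hot targets (one row per digit class). Calling reset afterward rewinds the queue so that the scoring loop does not skip this mini-batch.

```matlab
% Peek at one mini-batch to confirm its size and format.
[X,T] = next(mbq);
size(X)      % 28 28 1 128
dims(X)      % 'SSCB'
size(T)      % 10 128

% Rewind the queue so scoring starts from the first mini-batch again.
reset(mbq);
```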

Calculate Taylor-based importance scores of the prunable filters in the network by looping over the mini-batches of data. For each mini-batch:

  • Calculate the pruning activations and pruning gradients by using the modelLoss function defined at the end of this example.

  • Update the importance scores of the prunable filters by using the updateScore function.

while hasdata(mbq)
    [X,T] = next(mbq);
    [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet_1,X,T);
    prunableNet_1 = updateScore(prunableNet_1,pruningActivations,pruningGradients);
end

Finally, remove filters with the lowest importance scores to create a new TaylorPrunableNetwork object by using the updatePrunables function. By default, a single call to this function removes 8 filters. Observe that the new network prunableNet_2 has 40 prunable filters remaining.

prunableNet_2 = updatePrunables(prunableNet_1)
prunableNet_2 = 
  TaylorPrunableNetwork with properties:

      Learnables: [14x3 table]
           State: [6x3 table]
      InputNames: {'input'}
     OutputNames: {'softmax'}
    NumPrunables: 40

To compress the model further, run the custom pruning loop again and call the updatePrunables function on the resulting network.
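A minimal sketch of several pruning iterations, continuing from prunableNet_2, mbq, and modelLoss in this example. The names numPruningIterations and prunableNet are hypothetical, and the value 3 is arbitrary. The sketch assumes that calling reset on the minibatchqueue object rewinds it so that each iteration scores the filters over the full pruning set. MaxToPrune is passed explicitly with its default value of 8 only to show where you would change it.

```matlab
% Hypothetical number of score-and-prune iterations.
numPruningIterations = 3;
prunableNet = prunableNet_2;

for iteration = 1:numPruningIterations
    % Rewind the queue so each iteration sees the whole pruning set.
    reset(mbq);

    % Accumulate Taylor-based importance scores over all mini-batches.
    while hasdata(mbq)
        [X,T] = next(mbq);
        [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet,X,T);
        prunableNet = updateScore(prunableNet,pruningActivations,pruningGradients);
    end

    % Remove the filters with the lowest accumulated scores.
    prunableNet = updatePrunables(prunableNet,MaxToPrune=8);
    fprintf("Iteration %d: %d prunable filters remain\n", ...
        iteration,prunableNet.NumPrunables);
end
```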

Extract Pruned dlnetwork Object

Use the dlnetwork function to extract the pruned dlnetwork object from the pruned TaylorPrunableNetwork object. You can now use this compressed dlnetwork object to perform inference.

dlnet_2 = dlnetwork(prunableNet_2);
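The validation set created earlier (augimdsValidation) is not otherwise used in this example. A minimal sketch of one way to compare the classification accuracy of the original and pruned networks on it: the names mbqValidation, numCorrect, numObservations, and accuracy are hypothetical, and no accuracy values are claimed here because they depend on the random data split.

```matlab
% Mini-batch queue over the validation images, using the same
% preprocessing as the pruning queue.
mbqValidation = minibatchqueue(augimdsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);

numCorrect = [0 0];        % [original network, pruned network]
numObservations = 0;

while hasdata(mbqValidation)
    [X,T] = next(mbqValidation);

    % Convert scores and targets to plain numeric arrays on the CPU.
    Y1 = gather(extractdata(predict(dlnet_1,X)));
    Y2 = gather(extractdata(predict(dlnet_2,X)));
    T = gather(extractdata(T));

    % Predicted and true class indices, one column per observation.
    [~,idx1] = max(Y1,[],1);
    [~,idx2] = max(Y2,[],1);
    [~,idxT] = max(T,[],1);

    numCorrect = numCorrect + [sum(idx1 == idxT) sum(idx2 == idxT)];
    numObservations = numObservations + numel(idxT);
end

accuracy = numCorrect/numObservations
```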

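Compare the convolution layers of the original and the pruned dlnetwork objects. Observe that the three convolution layers in the pruned network have fewer filters: conv1 goes from 20 to 17 filters, conv2 from 20 to 18, and conv3 from 20 to 17. The 3 + 2 + 3 = 8 removed filters agree with the fact that, by default, a single call to the updatePrunables function removes 8 filters from the network. The number of input channels of conv2 and conv3 (17 and 18) also shrinks to match the pruned outputs of the preceding convolution layers.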

conv_layers_1 = dlnet_1.Layers([2 5 8])
conv_layers_1 = 
  3x1 Convolution2DLayer array with layers:

     1   'conv1'   2-D Convolution   20 5x5x1 convolutions with stride [1  1] and padding [0  0  0  0]
     2   'conv2'   2-D Convolution   20 3x3x20 convolutions with stride [1  1] and padding [1  1  1  1]
     3   'conv3'   2-D Convolution   20 3x3x20 convolutions with stride [1  1] and padding [1  1  1  1]
conv_layers_2 = dlnet_2.Layers([2 5 8])
conv_layers_2 = 
  3x1 Convolution2DLayer array with layers:

     1   'conv1'   2-D Convolution   17 5x5x1 convolutions with stride [1  1] and padding [0  0  0  0]
     2   'conv2'   2-D Convolution   18 3x3x17 convolutions with stride [1  1] and padding [1  1  1  1]
     3   'conv3'   2-D Convolution   17 3x3x18 convolutions with stride [1  1] and padding [1  1  1  1]
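Instead of reading the filter counts from the layer display, you can read the NumFilters property of each convolution layer. In this minimal sketch, convIdx and the other variable names are hypothetical, and the layer indices 2, 5, and 8 are taken from the layer listing earlier in this example. The commented values are the ones shown in the displays above.

```matlab
% Positions of the convolution layers in the Layer array.
convIdx = [2 5 8];
convLayers_1 = dlnet_1.Layers(convIdx);
convLayers_2 = dlnet_2.Layers(convIdx);

% Number of filters per convolution layer, before and after pruning.
numFilters_1 = [convLayers_1.NumFilters]        % 20 20 20
numFilters_2 = [convLayers_2.NumFilters]        % 17 18 17

% Total number of filters removed by one call to updatePrunables.
numRemoved = sum(numFilters_1 - numFilters_2)   % 8
```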

Supporting Functions

Model Loss Function

The modelLoss function takes a TaylorPrunableNetwork object net and a mini-batch of input data X with corresponding targets T. It returns the loss, the pruning activations of net, and the gradients of the loss with respect to those activations. To compute the gradients automatically, this function uses the dlgradient function.

function [loss, pruningActivations, pruningGradients] = modelLoss(net,X,T)

% Calculate network output for training.
[out, ~, pruningActivations] = forward(net,X);

% Calculate loss.
loss = crossentropy(out,T);

% Compute pruning gradients.
pruningGradients = dlgradient(loss,pruningActivations);
end

Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:

  1. Preprocess the images using the preprocessMiniBatchPredictors function.

  2. Extract the label data from the incoming cell array and concatenate it into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function [X,T] = preprocessMiniBatch(dataX,dataT)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX);

% Extract label data from cell and concatenate.
T = cat(2,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end
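To see the shapes involved, the following minimal sketch applies the same concatenation and one-hot encoding to three hypothetical labels from a 10-class digit problem. Concatenating along the second dimension yields a 1-by-3 categorical row vector, and encoding into the first dimension yields one row per class and one column per observation, which matches the "CB" layout of the softmax output.

```matlab
% Three hypothetical labels, each a categorical with classes "0" to "9".
classNames = string(0:9);
dataT = {categorical("3",classNames); categorical("7",classNames); categorical("3",classNames)};

% Concatenate along the second dimension: a 1-by-3 categorical array.
T = cat(2,dataT{1:end});

% One-hot encode along the first dimension: a 10-by-3 numeric array.
% Class "3" is the fourth category, so T(4,1) and T(4,3) are 1.
T = onehotencode(T,1);
size(T)
```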

Mini-Batch Predictors Preprocessing Function

The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension.

function X = preprocessMiniBatchPredictors(dataX)

% Concatenate.
X = cat(4,dataX{1:end});

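% Normalize the images to [0, 1]. Convert to single first, because
% dividing a uint8 array by 255 rounds each element to 0 or 1.
X = single(X)/255;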

end
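The following minimal sketch, which uses only base MATLAB, shows both effects with three hypothetical 28-by-28 uint8 images. Concatenating along the fourth dimension produces a 28-by-28-by-1-by-3 array, so the singleton third dimension serves as the channel dimension. It also shows that MATLAB integer arithmetic rounds, so dividing uint8 data by 255 produces 0 or 1 unless you first convert the data to a floating-point type.

```matlab
% Three hypothetical 28-by-28 grayscale images stored as uint8.
dataX = {zeros(28,28,"uint8"), 128*ones(28,28,"uint8"), 255*ones(28,28,"uint8")};

% Concatenating along dimension 4 gives height x width x channel x batch.
X = cat(4,dataX{1:end});
size(X)                     % 28 28 1 3

% uint8 arithmetic rounds: 128/255 (about 0.502) becomes 1.
Xint = X/255;
% Converting to single first keeps fractional values in [0, 1].
Xsingle = single(X)/255;
```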

More About


Algorithms

For each mini-batch of input data in the pruning data set, you use the forward function to calculate the output of the deep learning network and the activations of the prunable filters. Then you calculate the gradients of the loss with respect to these activations by using automatic differentiation. You then pass the network, the activations, and the gradients to the updateScore function. For each prunable filter in the network, the updateScore function estimates, up to a first-order Taylor approximation, the change in loss that occurs if that filter is pruned from the network. Based on this change, the function associates an importance score with that filter and updates the TaylorPrunableNetwork object [1].
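For reference, [1] derives the criterion as follows. Setting a feature map h_i to zero changes the loss C on data set D by an amount that a first-order Taylor expansion approximates as the product of the gradient and the activation. For the feature map z_l^(k) produced by filter k in layer l, with M elements z_{l,m}^(k), the criterion averages this product over the elements. This is the formulation in the paper; the text above does not state whether updateScore applies exactly this normalization.

```latex
\left|\Delta C(h_i)\right| = \left|C(\mathcal{D}, h_i = 0) - C(\mathcal{D}, h_i)\right| \approx \left|\frac{\partial C}{\partial h_i}\, h_i\right|

\Theta_{TE}\left(z_l^{(k)}\right) = \left|\frac{1}{M}\sum_{m=1}^{M}\frac{\partial C}{\partial z_{l,m}^{(k)}}\, z_{l,m}^{(k)}\right|
```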

Inside the custom pruning loop, you accumulate importance scores for the prunable filters over all mini-batches of the pruning data set. Then you pass the network object to the updatePrunables function. This function prunes the filters that have the lowest importance scores and therefore the smallest estimated effect on the network loss. The optional name-value argument MaxToPrune, which has a default value of 8, determines the number of filters that a single call to the updatePrunables function prunes.
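To make the score-and-prune steps concrete, here is a minimal base-MATLAB sketch of a first-order Taylor criterion on plain arrays. It is not the implementation of updateScore or updatePrunables. All data and names are hypothetical, and maxToPrune only plays the role of MaxToPrune. The sketch assumes two choices: the absolute spatial mean of activation times gradient is averaged over the observations in a mini-batch, and the scores are summed across mini-batches.

```matlab
% Hypothetical activations and gradients for ONE prunable layer with
% 3 filters: 2x2 spatial size, batch of 2, stored as H x W x C x N.
A1 = reshape(single(1:24),2,2,3,2);
G1 = 0.1*ones(2,2,3,2,"single");
G1(:,:,2,:) = 0;                 % filter 2 does not affect the loss here

% Second hypothetical mini-batch: same activations, other gradients.
A2 = A1;
G2 = -0.1*ones(2,2,3,2,"single");

numFilters = size(A1,3);
scores = zeros(numFilters,1,"single");
batches = {A1,G1; A2,G2};

for b = 1:size(batches,1)
    A = batches{b,1};
    G = batches{b,2};
    % Taylor term per filter and observation: absolute value of the
    % spatial mean of activation .* gradient (1 x 1 x C x N).
    perObs = abs(mean(A.*G,[1 2]));
    % Average over observations, then accumulate across mini-batches.
    scores = scores + reshape(mean(perObs,4),[],1);
end

scores                            % approximately [1.7; 1.25; 3.3]

% Select the filters with the lowest accumulated scores.
maxToPrune = 1;
[~,order] = sort(scores,"ascend");
toPrune = sort(order(1:maxToPrune))'    % 2
```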

Together, these steps form a single pruning iteration. To compress your model further, repeat them in a loop.

References

[1] Molchanov, Pavlo, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. "Pruning Convolutional Neural Networks for Resource Efficient Inference." Preprint, submitted June 8, 2017. https://arxiv.org/abs/1611.06440.

Version History

Introduced in R2022a