Compare Deep Learning Models Using ROC Curves

This example shows how to use receiver operating characteristic (ROC) curves to compare the performance of deep learning models.

A ROC curve shows the true positive rate (TPR), or sensitivity, versus the false positive rate (FPR), or 1 - specificity, for different thresholds of classification scores. The area under a ROC curve (AUC) is the integral of the curve (the TPR values) with respect to the FPR values from zero to one. The AUC provides an aggregate performance measure across all possible thresholds. AUC values are in the range [0, 1], and larger AUC values indicate better classifier performance.
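To make the integral concrete, the following minimal sketch computes the ROC points for a single binary problem by sweeping the threshold over the observed scores, and then approximates the AUC with the trapezoidal rule (trapz). The labels and scores are made up for illustration, and the sketch is not a description of how rocmetrics works internally.

```matlab
% Hypothetical binary problem: true = positive class.
labels = logical([1 1 0 1 0 0 1 0]);
scores = [0.9 0.8 0.7 0.6 0.55 0.4 0.35 0.1];

% Sweep the threshold from "nothing is positive" (Inf) down to the lowest score.
thresholds = [Inf, sort(unique(scores), 'descend')];
nPos = sum(labels);
nNeg = sum(~labels);
tpr = zeros(size(thresholds));
fpr = zeros(size(thresholds));
for k = 1:numel(thresholds)
    predictedPos = scores >= thresholds(k);          % positive at this threshold
    tpr(k) = sum(predictedPos & labels) / nPos;      % sensitivity
    fpr(k) = sum(predictedPos & ~labels) / nNeg;     % 1 - specificity
end

% Integrate TPR with respect to FPR (FPR never decreases along the sweep).
auc = trapz(fpr, tpr);
fprintf('AUC = %.4f\n', auc);
```

For these eight observations, the curve runs from (0, 0) to (1, 1) and the sketch reports an AUC of 0.75.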

  • A perfect classifier always correctly assigns positive class observations to the positive class. Its ROC curve reaches a TPR of 1 at an FPR of 0, and its AUC is 1.

  • A random classifier returns random score values and, on average, has the same values for the FPR and TPR at every threshold. Its ROC curve follows the diagonal line, and its AUC is about 0.5.
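Both cases can be checked with an equivalent way of computing the AUC: the AUC equals the probability that a randomly chosen positive observation receives a higher score than a randomly chosen negative observation, with ties counting as one half. The sketch below uses this pairwise form (a different technique from integrating the ROC curve, with the same result) on synthetic data. The function handle aucPairwise is a hypothetical helper, not part of any toolbox, and it relies on implicit expansion (R2016b or later).

```matlab
% Pairwise AUC: fraction of (positive, negative) pairs in which the positive
% observation scores higher; ties count as 1/2.
% s: row vector of scores, y: logical row vector (true = positive class).
aucPairwise = @(s, y) mean(mean((s(y)' > s(~y)) + 0.5 * (s(y)' == s(~y))));

rng(0);                            % reproducible synthetic data
n = 1000;
labels = rand(1, n) < 0.5;         % roughly balanced classes

perfectScores = double(labels);    % every positive outscores every negative
randomScores = rand(1, n);         % scores unrelated to the labels

aucPerfect = aucPairwise(perfectScores, labels);
aucRandom = aucPairwise(randomScores, labels);
fprintf('Perfect: %.3f, random: %.3f\n', aucPerfect, aucRandom);
```

The perfect scores give an AUC of exactly 1, and the random scores give a value close to 0.5.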

For a multiclass classification problem, the rocmetrics function formulates a set of one-versus-all binary classification problems with one binary problem for each class and finds a ROC curve for each class using the corresponding binary problem. Each binary problem assumes one class as positive and the rest as negative.
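The following sketch illustrates how the one-versus-all problems are formed. For class k, observations of class k are positive and all other observations are negative, and column k of the score matrix supplies the scores for that binary problem. The class names, labels, and scores are hypothetical, and the sketch only builds the binary labels rather than calling rocmetrics.

```matlab
% Hypothetical three-class data: one score column per class, in classNames order.
classNames = {'daisy', 'dandelion', 'roses'};
trueLabels = {'daisy'; 'roses'; 'dandelion'; 'daisy'};
scores = [0.7 0.2 0.1
          0.1 0.3 0.6
          0.2 0.5 0.3
          0.4 0.4 0.2];

numObs = numel(trueLabels);
K = numel(classNames);
binaryLabels = false(numObs, K);
for k = 1:K
    % Binary problem k: class k is positive, every other class is negative.
    binaryLabels(:, k) = strcmp(trueLabels, classNames{k});
    % scores(:, k) holds the scores that this binary problem thresholds.
    fprintf('%s: %d positive, %d negative, mean positive score %.2f\n', ...
        classNames{k}, sum(binaryLabels(:, k)), sum(~binaryLabels(:, k)), ...
        mean(scores(binaryLabels(:, k), k)));
end
```

Each observation is positive in exactly one of the K binary problems and negative in the others.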

This example shows how to use ROC curves and AUC values to compare two methods of training a deep neural network for image classification.

  • Train a small network from scratch.

  • Adapt a pretrained GoogLeNet network for new data using transfer learning.

Load Data

Download and extract the Flowers [1] data set. The Flowers data set contains 3670 images of flowers belonging to five classes (daisy, dandelion, roses, sunflowers, and tulips).

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

dataFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(dataFolder,"dir")
    fprintf("Downloading Flowers data set (218 MB)... ")
    websave(filename,url);
    untar(filename,downloadFolder)
    fprintf("Done.\n")
end
Downloading Flowers data set (218 MB)... 
Done.
numClasses = 5;

Create an image datastore containing the photos of the flowers.

imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");

Partition the data into training, validation, and test sets using the splitEachLabel function. Use 60% of the data for training, set aside 20% for validation, and set aside 20% for testing.

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2,0.2,"randomize");
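splitEachLabel splits the files of each label separately, so every set keeps roughly the class proportions of the full data set. As a rough illustration, the following sketch shuffles the observations of each class and assigns 60%, 20%, and 20% of them to the training, validation, and test sets. The labels are hypothetical, and the exact rounding that splitEachLabel applies may differ from this sketch.

```matlab
rng(1);                                     % reproducible shuffle
% Hypothetical labels: 10 daisy images and 20 roses images.
labels = [repmat({'daisy'}, 10, 1); repmat({'roses'}, 20, 1)];
[classes, ~, classIdx] = unique(labels);    % classIdx(i): class of observation i

trainIdx = []; valIdx = []; testIdx = [];
for k = 1:numel(classes)
    members = find(classIdx == k);
    members = members(randperm(numel(members)));   % shuffle within the class
    nTrain = round(0.6 * numel(members));
    nVal = round(0.2 * numel(members));
    trainIdx = [trainIdx; members(1:nTrain)];
    valIdx = [valIdx; members(nTrain+1:nTrain+nVal)];
    testIdx = [testIdx; members(nTrain+nVal+1:end)];   % remaining 20%
end
fprintf('Train: %d, validation: %d, test: %d\n', ...
    numel(trainIdx), numel(valIdx), numel(testIdx));
```

Because each class is split on its own, both classes appear in all three sets in the same 60/20/20 ratio.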

Prepare Networks

Create two image classification models. For the first model, build and train a deep neural network from scratch. For the second model, adapt a pretrained GoogLeNet network for new data using transfer learning. This example requires the Deep Learning Toolbox™ Model for GoogLeNet Network support package. If this support package is not installed, then the googlenet function provides a download link. The GoogLeNet network requires images of size 224-by-224-by-3 pixels.

inputSize = [224 224 3];

Create New Network

Create a small network from scratch. Set the input size to match the input size of the GoogLeNet pretrained network. To reduce overfitting, include a dropout layer.

numFilters = 16;
filterSize = 3;
poolSize = 2;

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(filterSize,Stride=2)

    convolution2dLayer(filterSize,2*numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize,Stride=2)

    convolution2dLayer(filterSize,4*numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize,Stride=2)
   
    dropoutLayer(0.8)

    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

lgraphSmallNet = layerGraph(layers);
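To see where the learnable parameters of the small network come from, the following sketch traces the height and width of the feature maps through the three pooling layers and counts the weights and biases of the convolutional and fully connected layers. It assumes that the pooling layers use no padding (the maxPooling2dLayer default), so each one outputs floor((H - poolSize)/stride) + 1 rows and columns, and that the convolutional layers preserve the size because they use "same" padding with the default stride of 1.

```matlab
% Output size of a pooling layer with no padding (assumed, the default):
poolOut = @(h, pool, stride) floor((h - pool) / stride) + 1;

h = 224;                  % input height and width (inputSize)
h = poolOut(h, 3, 2);     % first pooling layer uses filterSize = 3, stride 2
h = poolOut(h, 2, 2);     % second pooling layer uses poolSize = 2, stride 2
h = poolOut(h, 2, 2);     % third pooling layer uses poolSize = 2, stride 2

% Weights + biases of each 3-by-3 convolution (16, 32, and 64 filters)
convParams = (3*3*3*16 + 16) + (3*3*16*32 + 32) + (3*3*32*64 + 64);

% Weights + biases of the fully connected layer (numClasses = 5 outputs)
fcInputs = h * h * 64;
fcParams = fcInputs * 5 + 5;

fprintf('Final feature map: %d-by-%d-by-64\n', h, h);
fprintf('Convolution parameters: %d, fully connected parameters: %d\n', ...
    convParams, fcParams);
```

With these assumptions, the final feature map is 27-by-27-by-64, so the fully connected layer holds far more weights than the three convolutional layers combined.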

Prepare GoogLeNet Network

Adapt a pretrained GoogLeNet network for the new data.

Load GoogLeNet.

lgraphGoogLeNet = layerGraph(googlenet);

To use a pretrained network for transfer learning, you must adapt the network to match your new data set.

  • Replace the last learnable layer with a new layer that is adapted to the new data. For GoogLeNet, this layer is the final fully connected layer, loss3-classifier. Set the output size of the new layer to match the number of classes in the new data. To make the new layer learn faster than the transferred layers, increase its weight and bias learn rate factors.

  • Replace the output layer, output, with a new output layer that is adapted to the new data.

newLearnableLayer = fullyConnectedLayer(numClasses, ...
    WeightLearnRateFactor=10, ...
    BiasLearnRateFactor=10);
lgraphGoogLeNet = replaceLayer(lgraphGoogLeNet,"loss3-classifier",newLearnableLayer);

newOutputLayer = classificationLayer("Name","ClassificationLayer_predictions");
lgraphGoogLeNet = replaceLayer(lgraphGoogLeNet,"output",newOutputLayer);

Compare Networks

Compare the size of the networks using analyzeNetwork.

analyzeNetwork(lgraphGoogLeNet)
analyzeNetwork(lgraphSmallNet)

The small network has 17 layers and nearly 300,000 learnable parameters. The larger GoogLeNet network has 144 layers and nearly 6 million learnable parameters. Although the pretrained network is larger, transfer learning needs fewer training epochs, because the network has already learned features that serve as a starting point for the new data.

Prepare Data

The networks require input images of size 224-by-224-by-3. To automatically resize the training images, use an augmented image datastore. Also specify augmentation operations to perform on the training images: randomly reflect the images left to right (about the vertical axis), and randomly scale them by a factor between 0.5 and 1.5, applying the same factor horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

augmenter = imageDataAugmenter(RandXReflection=true,RandScale=[0.5 1.5]);
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain,DataAugmentation=augmenter);

To automatically resize the validation images without performing further data augmentation, use an augmented image datastore without specifying any additional preprocessing operations.

augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

Training Options

Train the small network for 150 epochs with an initial learning rate of 0.002.

optsSmallNet = trainingOptions("sgdm", ...
    MaxEpochs=150, ...
    InitialLearnRate=0.002, ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=150, ...
    Verbose=false, ...
    Plots="training-progress");

You do not need to train the pretrained network for as many epochs, so set the maximum number of epochs to 15. To slow learning in the transferred layers of the pretrained network, choose a small initial learning rate of 0.0001. Because you increased the learn rate factors of the new learnable layer, that layer still learns faster than the transferred layers.

optsGoogLeNet = optsSmallNet;
optsGoogLeNet.MaxEpochs = 15;
optsGoogLeNet.InitialLearnRate = 0.0001;
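The learn rate factors multiply the global learning rate, so the effective learning rate of the new fully connected layer is 10 times the initial learning rate in optsGoogLeNet. This small sketch does the arithmetic, assuming that the transferred layers keep a learn rate factor of 1.

```matlab
% Effective learning rate = global learning rate * layer learn rate factor.
globalRate = 0.0001;          % optsGoogLeNet.InitialLearnRate
newLayerFactor = 10;          % WeightLearnRateFactor and BiasLearnRateFactor of the new layer
transferredFactor = 1;        % assumed factor of the transferred layers

newLayerRate = globalRate * newLayerFactor;
transferredRate = globalRate * transferredFactor;
fprintf('New layer: %g, transferred layers: %g\n', newLayerRate, transferredRate);
```

Both rates are smaller than the initial learning rate of 0.002 used for the small network.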

Train Networks

Train the networks using trainNetwork. Although it is larger, the adapted GoogLeNet network converges more quickly than the small network.

netSmallNet = trainNetwork(augimdsTrain,lgraphSmallNet,optsSmallNet);

Figure: Training progress of the small network (16-Jun-2022 10:08:35).

netGoogLeNet = trainNetwork(augimdsTrain,lgraphGoogLeNet,optsGoogLeNet);

Figure: Training progress of the GoogLeNet network (16-Jun-2022 10:23:42).

Compare Network Accuracy

Test the classification accuracy of the two networks by comparing the predictions on the test set with the true labels.

Prepare the test data.

augimdsTest = augmentedImageDatastore(inputSize,imdsTest);

Classify the test images using the two networks.

[YTestSmallNet,scoresSmallNet] = classify(netSmallNet,augimdsTest);
[YTestGoogLeNet,scoresGoogLeNet] = classify(netGoogLeNet,augimdsTest);

Compare the accuracy of the two networks.

TTest = imdsTest.Labels;
accSmallNet = sum(TTest == YTestSmallNet)/numel(TTest)
accSmallNet = 0.7361
accGoogLeNet = sum(TTest == YTestGoogLeNet)/numel(TTest)
accGoogLeNet = 0.9034

Plot confusion charts for each model. For each class, the GoogLeNet network performs better than the smaller network. Both networks have the greatest difficulty classifying images from the daisy and roses classes.

figure 
tiledlayout(1,2)
nexttile
confusionchart(TTest,YTestSmallNet)
title("SmallNet")
nexttile
confusionchart(TTest,YTestGoogLeNet)
title("GoogLeNet")

Figure: Confusion charts titled SmallNet and GoogLeNet.
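A confusion matrix also determines the accuracy values computed earlier: the overall accuracy is the sum of the diagonal divided by the total number of observations, and the per-class recall is each diagonal entry divided by its row sum. The following sketch builds a confusion matrix with accumarray instead of confusionchart, using hypothetical labels, with rows as true classes and columns as predicted classes (the same orientation confusionchart uses).

```matlab
% Hypothetical true and predicted labels for three classes.
classNames = {'daisy', 'roses', 'tulips'};
trueLabels = {'daisy'; 'daisy'; 'roses'; 'roses'; 'tulips'; 'tulips'};
predLabels = {'daisy'; 'roses'; 'roses'; 'tulips'; 'tulips'; 'tulips'};

% Map class names to column indices, then count (true, predicted) pairs.
[~, trueIdx] = ismember(trueLabels, classNames);
[~, predIdx] = ismember(predLabels, classNames);
K = numel(classNames);
C = accumarray([trueIdx predIdx], 1, [K K]);   % C(i,j): true class i predicted as j

recall = diag(C) ./ sum(C, 2);     % per-class fraction classified correctly
accuracy = trace(C) / sum(C(:));   % same as mean(trueIdx == predIdx)
disp(C)
fprintf('Accuracy: %.4f\n', accuracy);
```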

Compare ROC Curves

You can use ROC curves to compare the performance of the two networks.

Create rocmetrics objects using the true labels in TTest and the classification scores from each trained network. Specify the column order of the classification scores by extracting the class names from the output layer of the small network. Both networks were trained on the same data, so their output layers list the classes in the same order.

classNames = netSmallNet.Layers(end).Classes;
rocSmallNet = rocmetrics(TTest,scoresSmallNet,classNames);
rocGoogLeNet = rocmetrics(TTest,scoresGoogLeNet,classNames);

rocSmallNet and rocGoogLeNet are rocmetrics objects that store the AUC values and performance metrics for each class in the AUC and Metrics properties. Plot the ROC curves for each class. You can click on any part of the ROC curve to see the threshold corresponding to the TPR and FPR values that you select.

The diagonal line indicates the performance of a random classifier. The smaller network performs best for the sunflowers and dandelion classes. However, across all five classes, the larger network performs better than the smaller network.

figure
tiledlayout(1,2,TileSpacing="compact")
nexttile
plot(rocSmallNet,ShowModelOperatingPoint=false)
legend(classNames)
title("ROC Curve: SmallNet")
nexttile
plot(rocGoogLeNet,ShowModelOperatingPoint=false)
legend(classNames)
title("ROC Curve: GoogLeNet")

Figure: Two ROC plots, titled "ROC Curve: SmallNet" and "ROC Curve: GoogLeNet", each with one curve per class (daisy, dandelion, roses, sunflowers, tulips) and the random-classifier diagonal.
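Clicking the plot is not the only way to relate a point on a ROC curve to a threshold. The following sketch picks, for hypothetical data, the threshold with the highest TPR among the thresholds whose FPR does not exceed a target value. For a rocmetrics object, the same kind of lookup can be done on the Threshold, FalsePositiveRate, and TruePositiveRate variables of the Metrics table; this sketch computes the values directly instead.

```matlab
% Hypothetical binary problem: true = positive class.
labels = logical([1 1 0 1 0 0 1 0]);
scores = [0.9 0.8 0.7 0.6 0.55 0.4 0.35 0.1];
targetFPR = 0.25;                      % largest acceptable false positive rate

% TPR and FPR for each candidate threshold (each observed score).
thresholds = sort(unique(scores), 'descend');
tpr = arrayfun(@(t) sum(scores >= t & labels) / sum(labels), thresholds);
fpr = arrayfun(@(t) sum(scores >= t & ~labels) / sum(~labels), thresholds);

% Among thresholds that meet the FPR target, keep the one with the best TPR.
ok = fpr <= targetFPR;
okThresholds = thresholds(ok);
[bestTPR, i] = max(tpr(ok));
chosenThreshold = okThresholds(i);
fprintf('Threshold %.2f gives TPR %.2f\n', chosenThreshold, bestTPR);
```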

Compare AUC Values

You can access the AUC value for each class using the rocmetrics object.

aucSmallNet = rocSmallNet.AUC;
aucGoogLeNet = rocGoogLeNet.AUC;

Compare the AUC values for each class. For each class, the GoogLeNet network produces AUC values close to 1.

figure
bar([aucSmallNet; aucGoogLeNet]')
xticklabels(classNames)
legend(["SmallNet","GoogLeNet"],Location="southeast")
title("AUC")

Figure: Bar chart titled AUC comparing SmallNet and GoogLeNet for each class.

Investigate Specific Class

Investigate the ROC curves for the sunflowers class. By default, the plot function displays the class names and the AUC values in the legend. To show the model names in the legend instead of the class names, modify the DisplayName property of the ROCCurve objects that the plot function returns. The model operating point represents the FPR and TPR corresponding to the typical threshold value. For the sunflowers class, both models perform well.

classToInvestigate = "sunflowers";

figure
c = cell(2,1);
g = cell(2,1);
[c{1},g{1}] = plot(rocSmallNet,ClassNames=classToInvestigate);
hold on
[c{2},g{2}] = plot(rocGoogLeNet,ClassNames=classToInvestigate);
modelNames = ["SmallNet","GoogLeNet"];
for i = 1:2
    c{i}.DisplayName = replace(c{i}.DisplayName, ...
        classToInvestigate,modelNames(i));
    g{i}(1).DisplayName = join([modelNames(i),"Model Operating Point"]);
end
title("ROC Curve","Class: " + classToInvestigate)
hold off

Figure: ROC Curve for the sunflowers class. Legend: SmallNet (AUC = 0.9573), SmallNet Model Operating Point, GoogLeNet (AUC = 0.993), GoogLeNet Model Operating Point.

Compare Average ROC Curves

Find the average ROC curves. Specify AverageROCType as "macro" to compute metrics for the average ROC curve using the macro-averaging method. Macro-averaging finds the average values of the FPR and TPR by averaging the values of the one-versus-all binary classification problems for each class. To learn more, see Average of Performance Metrics.

figure
averageType = "macro";
plot(rocSmallNet,AverageROCType=averageType,ClassNames=[])
hold on
plot(rocGoogLeNet,AverageROCType=averageType,ClassNames=[])
legend(["SmallNet (" + averageType + "-average)", ...
    "GoogLeNet (" + averageType + "-average)"])
hold off

Figure: ROC Curve showing the macro-average curves for SmallNet and GoogLeNet.
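The following sketch follows the description of macro-averaging given above: for each threshold in a shared set, it computes the FPR and TPR of every one-versus-all problem, averages them across classes, and then integrates the averaged curve with trapz. The scores and labels are hypothetical, and rocmetrics may select its thresholds and interpolate differently, so treat this as an illustration rather than a reimplementation.

```matlab
% Hypothetical scores: one row per observation, one column per class.
scores = [0.7 0.2 0.1
          0.1 0.3 0.6
          0.2 0.5 0.3
          0.4 0.4 0.2];
trueIdx = [1; 3; 2; 1];                  % true class index of each row
K = size(scores, 2);

% Shared thresholds: Inf (nothing positive) plus every observed score.
thresholds = [Inf; sort(unique(scores(:)), 'descend')];
avgFPR = zeros(numel(thresholds), 1);
avgTPR = zeros(numel(thresholds), 1);
for t = 1:numel(thresholds)
    fpr = zeros(1, K);
    tpr = zeros(1, K);
    for k = 1:K
        isPos = (trueIdx == k);                     % one-versus-all labels
        predPos = scores(:, k) >= thresholds(t);    % positive at this threshold
        tpr(k) = sum(predPos & isPos) / sum(isPos);
        fpr(k) = sum(predPos & ~isPos) / sum(~isPos);
    end
    avgFPR(t) = mean(fpr);                % macro-average across classes
    avgTPR(t) = mean(tpr);
end

macroAUC = trapz(avgFPR, avgTPR);
fprintf('Macro-average AUC: %.4f\n', macroAUC);
```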

References

[1] The TensorFlow Team. Flowers. http://download.tensorflow.org/example_images/flower_photos.tgz
