
Interpret Deep Network Predictions on Tabular Data Using LIME

This example shows how to use the local interpretable model-agnostic explanations (LIME) technique to understand the predictions of a deep neural network that classifies tabular data. You can use the LIME technique to understand which predictors are most important to the classification decision of a network.

In this example, you interpret a feature data classification network using LIME. For a specified query observation, LIME generates a synthetic data set whose statistics for each feature match those of the real data set. LIME passes this synthetic data set through the deep neural network to obtain classifications, and then fits a simple, interpretable model to the synthetic observations and their predicted classes. When fitting this model, LIME weights each synthetic observation by its distance from the query observation, so the explanation is "local" to that observation. You can then use the simple model to understand how important the top few features are to the classification decision of the network.
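The procedure can be sketched from scratch in a few lines of base MATLAB. This is a minimal illustration of the idea under stated assumptions, not the implementation used by the lime function: the two-feature data set, the threshold black box, independent Gaussian sampling, the Euclidean distance on standardized features, the squared exponential kernel, and the weighted least-squares surrogate are all choices made here for clarity.

```matlab
% Minimal from-scratch sketch of the LIME idea for one query point.
% Assumptions (not taken from the lime function): numeric predictors,
% independent Gaussian sampling per feature, Euclidean distance on
% standardized features, a squared exponential kernel, and a weighted
% least-squares linear surrogate fitted to a 0/1 class indicator.
rng(0)

% Hypothetical "real" data set: 200 observations, 2 features
X = [5 + 0.5*randn(200,1), 3 + 0.3*randn(200,1)];

% Hypothetical black box: class 1 if feature 1 exceeds 5, otherwise class 0
blackbox = @(Z) double(Z(:,1) > 5);

query = [5.2 3.1];       % query observation to explain
numSamples = 2000;       % number of synthetic observations
kernelWidth = 0.75;      % width of the weighting kernel

% 1. Synthetic data whose per-feature mean and standard deviation match X
mu = mean(X);
sigma = std(X);
Z = mu + sigma .* randn(numSamples, size(X,2));

% 2. Black-box predictions for the synthetic data
y = blackbox(Z);

% 3. Weight each synthetic sample by its distance from the query
d = sqrt(sum(((Z - query) ./ sigma).^2, 2));
w = exp(-d.^2 / kernelWidth^2);

% 4. Weighted least-squares fit of y ~ b0 + Z*b (scale rows by sqrt(w))
A = [ones(numSamples,1) Z];
sw = sqrt(w);
coef = (A .* sw) \ (y .* sw);

% Slope coefficients act as local feature importances
localImportance = coef(2:end)'
```

Because the hypothetical black box depends only on the first feature, its coefficient should clearly dominate the second one, which is the kind of signal a local explanation is meant to reveal.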

This example uses lime (Statistics and Machine Learning Toolbox) and fit (Statistics and Machine Learning Toolbox) to generate a synthetic data set and fit a simple interpretable model to the synthetic data set. To understand the predictions of a trained image classification neural network, use imageLIME. For more information, see Understand Network Predictions Using LIME.

Load Data

Load the Fisher iris data set. The data set contains 150 observations, each with four numeric input features that describe the plant and one categorical response that gives the plant species. Each observation belongs to one of three species: setosa, versicolor, or virginica. The four measurements are sepal length, sepal width, petal length, and petal width.

filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)
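A quick check can confirm what the load step provides: meas holds the four measurements, with columns in the order sepal length, sepal width, petal length, petal width (the order assumed by the features vector below), and species holds the labels, 50 per class. This sketch only inspects the loaded variables; speciesCounts is a name introduced here for illustration.

```matlab
% Load the Fisher iris data and confirm its shape and class balance
filename = fullfile(toolboxdir('stats'),'statsdemos','fisheriris.mat');
load(filename)                                 % creates meas and species
disp(size(meas))                               % observations x measurements
speciesCounts = countcats(categorical(species))  % observations per species
```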

Convert the numeric data to a table.

features = ["Sepal length","Sepal width","Petal length","Petal width"];

predictors = array2table(meas,"VariableNames",features);
trueLabels = array2table(categorical(species),"VariableNames","Response");

Create a table of training data whose final column is the response.

data = [predictors trueLabels];

Calculate the number of observations, features, and classes.

numObservations = size(predictors,1);
numFeatures = size(predictors,2);
numClasses = length(categories(data{:,5}));

Split Data into Training, Validation, and Test Sets

Partition the data set into training, validation, and test sets. Set aside 15% of the data for validation and 15% for testing, and use the remaining 70% for training.

Determine the number of observations for each partition. Set the random seed to make the data splitting and CPU training reproducible.

rng('default');
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);

Create an array of random indices corresponding to the observations and partition it using the partition sizes.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);
idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);
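For the 150 iris observations, the floor operations give 105 training and 22 validation observations, and the test set receives the remaining 23, slightly more than 15% because of the rounding. The arithmetic can be checked directly; numObservationsTest is a name introduced here for illustration.

```matlab
% Partition sizes produced by the 70/15/15 split for 150 observations
numObservations = 150;
numObservationsTrain = floor(0.7*numObservations);
numObservationsValidation = floor(0.15*numObservations);
% The test set takes everything left after flooring the other two sizes
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation;
fprintf('Train: %d, validation: %d, test: %d\n', ...
    numObservationsTrain, numObservationsValidation, numObservationsTest)
```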

Partition the table of data into training, validation, and testing partitions using the indices.

dataTrain = data(idxTrain,:);
dataVal = data(idxValidation,:);
dataTest = data(idxTest,:);

Define Network Architecture

Create a simple multilayer perceptron with a single hidden layer of five neurons and ReLU activations. The feature input layer accepts data containing numeric scalars that represent features, such as the Fisher iris measurements.

numHiddenUnits = 5;
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
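To make the architecture concrete, the sketch below evaluates the same layer sequence by hand for one observation. The weights are random and purely hypothetical, so the class probabilities do not reproduce the trained network; the classificationLayer is omitted because it only computes the cross-entropy loss during training. The sketch also counts the learnable parameters: 4×5 + 5 in the hidden layer and 5×3 + 3 in the output layer, 43 in total.

```matlab
% Hand-written forward pass of featureInput -> FC(5) -> ReLU -> FC(3) -> softmax
% The weights are hypothetical random values, not the trained network's.
rng(0)
numFeatures = 4; numHiddenUnits = 5; numClasses = 3;
W1 = randn(numHiddenUnits,numFeatures); b1 = zeros(numHiddenUnits,1);
W2 = randn(numClasses,numHiddenUnits);  b2 = zeros(numClasses,1);

x = [5.1; 3.5; 1.4; 0.2];                 % first observation of meas
h = max(W1*x + b1, 0);                    % fullyConnectedLayer + reluLayer
z = W2*h + b2;                            % fullyConnectedLayer
p = exp(z - max(z)) / sum(exp(z - max(z)))  % softmaxLayer (numerically stable)

numLearnables = numel(W1) + numel(b1) + numel(W2) + numel(b2)
```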

Define Training Options and Train Network

Train the network using stochastic gradient descent with momentum (SGDM). Set the maximum number of epochs to 30 and use a mini-batch size of 15, because the training set contains relatively few observations.

opts = trainingOptions("sgdm", ...
    "MaxEpochs",30, ...
    "MiniBatchSize",15, ...
    "Shuffle","every-epoch", ...
    "ValidationData",dataVal, ...
    "ExecutionEnvironment","cpu");

Train the network.

net = trainNetwork(dataTrain,layers,opts);
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       40.00% |       31.82% |       1.3060 |       1.2897 |          0.0100 |
|       8 |          50 |       00:00:00 |       86.67% |       90.91% |       0.4223 |       0.3656 |          0.0100 |
|      15 |         100 |       00:00:00 |       93.33% |       86.36% |       0.2947 |       0.2927 |          0.0100 |
|      22 |         150 |       00:00:00 |       86.67% |       81.82% |       0.2804 |       0.3707 |          0.0100 |
|      29 |         200 |       00:00:01 |       86.67% |       90.91% |       0.2268 |       0.2129 |          0.0100 |
|      30 |         210 |       00:00:01 |       93.33% |       95.45% |       0.2782 |       0.1666 |          0.0100 |
|======================================================================================================================|
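The iteration count in the training log follows from the partition and training options: 105 training observations split into mini-batches of 15 give 7 iterations per epoch, and 30 epochs give 210 iterations, which matches the final row. The sketch assumes that incomplete final mini-batches are discarded; here 105 divides evenly by 15, so that assumption does not affect the result.

```matlab
% Iterations implied by the training partition and training options
numObservationsTrain = 105;   % floor(0.7*150), from the partition step
miniBatchSize = 15;
maxEpochs = 30;
iterationsPerEpoch = floor(numObservationsTrain/miniBatchSize)
totalIterations = iterationsPerEpoch*maxEpochs
epochOfIteration50 = ceil(50/iterationsPerEpoch)   % epoch shown for iteration 50
```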

Assess Network Performance

Classify observations from the test set using the trained network.

predictedLabels = classify(net,dataTest);
trueLabels = dataTest{:,end};

Visualize the results using a confusion matrix.

figure
confusionchart(trueLabels,predictedLabels)

The network successfully uses the four plant features to predict the species of the test observations.
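The confusion chart can be summarized by a single accuracy value, the fraction of matching labels. The sketch below uses small hypothetical label vectors so that it runs on its own; in the example workspace, the same comparison applies to the predictedLabels and trueLabels variables computed above.

```matlab
% Hypothetical labels, used only to illustrate the accuracy computation
trueLabels = categorical(["setosa";"versicolor";"virginica";"virginica"]);
predictedLabels = categorical(["setosa";"versicolor";"versicolor";"virginica"]);

% Accuracy is the fraction of observations whose predicted label matches
accuracy = mean(predictedLabels == trueLabels)
```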

Understand How Different Predictors Are Important to Different Classes

Use LIME to understand the importance of each predictor to the classification decisions of the network.

Investigate the two most important predictors for each observation.

numImportantPredictors = 2;

Use lime to create a synthetic data set whose statistics for each feature match those of the real data set. Create a lime object from the function handle blackbox, which wraps the deep learning model, and the predictor data contained in predictors. Use a low 'KernelWidth' value so that lime uses weights focused on the samples near the query point.

blackbox = @(x)classify(net,x);
explainer = lime(blackbox,predictors,'Type','classification','KernelWidth',0.1);
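The effect of a low kernel width can be seen numerically. This sketch assumes a squared exponential (Gaussian) kernel of the form exp(-d^2/KernelWidth^2) and uses hypothetical distances; it compares the value 0.1 used here with a wider kernel of 0.75.

```matlab
% Weights from an assumed squared exponential kernel for two kernel widths
d = [0 0.05 0.1 0.2 0.5];        % hypothetical distances from the query point
wNarrow = exp(-d.^2 / 0.1^2);    % KernelWidth = 0.1, as in this example
wWide = exp(-d.^2 / 0.75^2);     % wider kernel, for comparison
disp([d' wNarrow' wWide'])       % columns: distance, narrow weight, wide weight
```

With the narrow kernel, samples at a distance of 0.5 receive essentially zero weight, so the simple model is fitted almost entirely to the samples closest to the query point.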

You can use the LIME explainer to identify the features that matter most to the decisions of the deep neural network. The fit function estimates the importance of each feature by fitting a simple linear model that approximates the neural network in the vicinity of a query observation.

Find the indices of the first two observations in the test data corresponding to the setosa class.

trueLabelsTest = dataTest{:,end};

label = "setosa";
idxSetosa = find(trueLabelsTest == label,2);
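The second argument of find limits the result to the first n matches, in order of position. Because the test set holds only 23 observations, a class could in principle appear fewer than two times, in which case the later idxSetosa(2) indexing would fail. The sketch below uses hypothetical labels and shows an optional guard, which is an addition not present in the original example.

```matlab
% Hypothetical labels to illustrate find with a maximum number of matches
labels = categorical(["virginica";"setosa";"versicolor";"setosa";"setosa"]);
idxSetosa = find(labels == "setosa", 2)   % indices of the first two matches

% Optional guard: stop early if fewer than two observations of the class exist
assert(numel(idxSetosa) == 2, "Fewer than two setosa observations found.")
```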

Use the fit function to fit a simple linear model for each of the first two test observations from the specified class.

explainerObs1 = fit(explainer,dataTest(idxSetosa(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxSetosa(2),1:4),numImportantPredictors);
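Besides plotting, you can read the selected predictors from the fitted object's ImportantPredictors property, which holds the indices of the important predictors. So that the sketch runs without training a network, it substitutes a classification tree (fitctree) for the deep network; that substitution and the predictor names without spaces are assumptions made for this illustration.

```matlab
% Self-contained sketch: a classification tree stands in for the deep network
load fisheriris
X = array2table(meas,"VariableNames", ...
    ["SepalLength","SepalWidth","PetalLength","PetalWidth"]);
Y = categorical(species);
mdl = fitctree(X,Y);

rng('default')                        % lime samples synthetic data randomly
explainer = lime(mdl);                % uses the predictor data stored in mdl
explainer = fit(explainer,X(1,:),2);  % explain the first observation

% Names of the predictors that the simple model selected
importantNames = X.Properties.VariableNames(explainer.ImportantPredictors)
```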

Plot the results.

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

For the setosa class, the most important predictors are a low petal length value and a high sepal width value.

Perform the same analysis for class versicolor.

label = "versicolor";
idxVersicolor = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVersicolor(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVersicolor(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

For the versicolor class, a high petal length value is important.

Finally, consider the virginica class.

label = "virginica";
idxVirginica = find(trueLabelsTest == label,2);

explainerObs1 = fit(explainer,dataTest(idxVirginica(1),1:4),numImportantPredictors);
explainerObs2 = fit(explainer,dataTest(idxVirginica(2),1:4),numImportantPredictors);

figure
subplot(2,1,1)
plot(explainerObs1);
subplot(2,1,2)
plot(explainerObs2);

For the virginica class, a high petal length value and a low sepal width value are important.

Validate LIME Hypothesis

The LIME plots suggest that a high petal length value is associated with the versicolor and virginica classes and a low petal length value is associated with the setosa class. You can investigate the results further by exploring the data.

Plot the petal length of each observation in the data set, grouped by species.

setosaIdx = ismember(data{:,end},"setosa");
versicolorIdx = ismember(data{:,end},"versicolor");
virginicaIdx = ismember(data{:,end},"virginica");

figure
hold on
plot(data{setosaIdx,"Petal length"},'.')
plot(data{versicolorIdx,"Petal length"},'.')
plot(data{virginicaIdx,"Petal length"},'.')
hold off

xlabel("Observation number within class")
ylabel("Petal length")
legend(["setosa","versicolor","virginica"])

The setosa class has much lower petal length values than the other classes, which is consistent with the LIME results.
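To back the visual comparison with numbers, you can summarize petal length per species. This sketch uses findgroups and splitapply, which are a choice made here rather than part of the original example; the variable names are illustrative.

```matlab
% Per-species summary of petal length (third column of meas, in cm)
load fisheriris
species = categorical(species);
petalLength = meas(:,3);

[g,speciesName] = findgroups(species);        % group index for each observation
minLength = splitapply(@min,petalLength,g);
meanLength = splitapply(@mean,petalLength,g);
maxLength = splitapply(@max,petalLength,g);
table(speciesName,minLength,meanLength,maxLength)
```

If the largest setosa petal length is below the smallest value of either other class, a single petal length threshold separates setosa from the rest, which fits the LIME finding that a low petal length drives the setosa prediction.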
