How can I obtain the Shapley values from a Neural Network Object?
MathWorks Support Team
on 8 Jul 2021
Edited: MathWorks Support Team
on 18 Jun 2024
I have created a neural network for pattern recognition with the "patternnet" function and would like to calculate its Shapley values by executing this code:
[x,t] = iris_dataset;
net = patternnet(7);
net = train(net,x,t)
queryPoint=x(:,1)'
explainer = shapley(net,x,'QueryPoint',queryPoint)
However, I receive the following error:
Error using shapley
Blackbox model must be a classification model, regression model, or function handle
Is there a way to obtain the Shapley values from a "network" object as the one above?
Accepted Answer
MathWorks Support Team
on 18 Jun 2024
Edited: MathWorks Support Team
on 18 Jun 2024
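The "shapley" function does not accept a "network" object directly, but you can obtain Shapley values for a pattern recognition network by passing "shapley" a function handle that returns the network's score for the class of interest. Note that "shapley" passes the predictor data to the function handle with one observation per row and expects one score per row in return (an n-by-1 column vector), whereas the network expects one observation per column and returns one column of scores per observation. The function handle therefore has to transpose its input before calling the network and transpose the selected row of scores before returning it.
Below is an example of how to do this using the Fisher iris data: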
% Train a neural network on the iris data
[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
% Choose an observation to explain. We need its class as an index.
x1 = x(:,1);
t1 = find(t(:,1));
% Plot Shapley values. For Setosa (the first class) the petal length (x3)
% is usually the most informative feature.
explainer = shapley( ...
@(x)predictScoreForSpecifiedClass(net,x,t1), ...
x', "QueryPoint", x1' );
plot(explainer)
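% Helper function (local functions must appear at the end of a script)
function score = predictScoreForSpecifiedClass(net, x, classIndex)
% shapley passes x as n-by-4 (one observation per row); the network expects 4-by-n.
Y = net(x');                 % 3-by-n matrix of class scores
score = Y(classIndex,:)';    % n-by-1 scores for the requested class
end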
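If you want an explanation for every class rather than only the true class of the query point, one option is to loop over the class indices and build one explainer per class. The following is a minimal sketch, not an official MathWorks example: it assumes Deep Learning Toolbox plus Statistics and Machine Learning Toolbox R2021a or later (where "shapley" was introduced), and it assumes the explainer's "ShapleyValues" table has "Predictor" and "ShapleyValue" variables, as it does for a function-handle model with a single query point. It reuses the answer's helper unchanged; "X", "q", "f1", "phi" and "numClasses" are illustrative names.

```matlab
% Sketch (not an official MathWorks example): Shapley values of one iris
% observation for every class of a patternnet classifier.
% Assumes Deep Learning Toolbox and Statistics and Machine Learning Toolbox
% R2021a or later (first release with shapley).
[x,t] = iris_dataset;                % x: 4-by-150 features, t: 3-by-150 one-hot targets
net = patternnet(10);
net.trainParam.showWindow = false;   % train without opening the training window
net = train(net,x,t);

X = x';                              % shapley wants one observation per row (150-by-4)
q = X(1,:);                          % query point as a 1-by-4 row vector
numClasses = size(t,1);

% Shape check: the handle must map an n-by-4 input to an n-by-1 score vector.
f1 = @(xr) predictScoreForSpecifiedClass(net,xr,1);
assert(isequal(size(f1(X)), [size(X,1) 1]))

% One explainer per class; collect the values as a predictors-by-classes matrix.
phi = zeros(size(X,2), numClasses);
for k = 1:numClasses
    f = @(xr) predictScoreForSpecifiedClass(net,xr,k);
    explainer = shapley(f, X, "QueryPoint", q);
    phi(:,k) = explainer.ShapleyValues.ShapleyValue;
end

% Label rows by predictor name and columns by class index, then display.
names = cellstr(string(explainer.ShapleyValues.Predictor));
disp(array2table(phi, "RowNames", names, ...
    "VariableNames", cellstr("Class" + string(1:numClasses))))

% Same helper as in the answer: rows in, transposed for the network, rows out.
function score = predictScoreForSpecifiedClass(net, x, classIndex)
Y = net(x');                  % 3-by-n matrix of class scores
score = Y(classIndex,:)';     % n-by-1 scores for the requested class
end
```

Each column of "phi" holds the values that the answer's single-class explainer would produce for that class index; calling "plot(explainer)" inside the loop would show the corresponding bar chart for each class.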