Precision-Recall curve for Multiclass classification
I have been trying hard to find any document or example on plotting a precision-recall curve for multiclass classification, but it seems there is no direct way to do it. How would I make a precision-recall curve for my model?
Following is the code I use to get the confusion matrix:
fpath = 'E:\Research Data\Four Classes';
testData = fullfile(fpath, 'Test');
testDatastore = imageDatastore(testData, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
% Classify each test image with the trained network (incep)
allclass = [];
for i = 1:numel(testDatastore.Labels)
    I = readimage(testDatastore, i);
    class = classify(incep, I);
    allclass = [allclass class]; %#ok<AGROW>
end
predicted = allclass';
% figure
% plotconfusion(testDatastore.Labels, predicted)
3 comments
the cyclist
on 28 Jul 2023
Before thinking about this as a MATLAB question, there is a conceptual question.
Ordinarily, a precision-recall curve is used for binary classification, where there are clear definitions of "true positive", "false positive", "true negative", and "false negative". These counts are the foundation of ROC and precision-recall curves.
For multiclass classification, it is not so straightforward. You'll need to decide what you want to plot (and why). One possibility is to recast the problem as "one versus all" and then make a P-R curve for each class.
I suggest you do some reading about the conceptual part before tackling the MATLAB problem. I found this article to be useful.
Laraib
on 28 Jul 2023
help roc
Answers (1)
Given a multiclass classification problem, you can create a Precision-Recall curve for each class by considering the one-vs-all binary classification problem for each class. The Precision-Recall curves can be built with:
- The rocmetrics function (introduced in R2022a). See https://www.mathworks.com/help/stats/rocmetrics.html and the specific Precision-Recall curve example: openExample('stats/SpecifyScoresAsVectorExample')
- The perfcurve function (introduced in R2009a). See https://www.mathworks.com/help/stats/perfcurve.html . You can see example code at https://www.mathworks.com/matlabcentral/answers/217604-how-to-calculate-precision-and-recall-in-matlab
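As a rough sketch of the perfcurve route: its XCrit/YCrit criteria include "reca" (recall) and "prec" (precision), so calling it once per class with the corresponding score column gives a one-vs-all P-R curve for each class. The variable names below (trueLabels, scores, classNames) are placeholders for whatever your classifier produces.

```matlab
% Hedged sketch of one-vs-all P-R curves with perfcurve.
% trueLabels: N-by-1 vector of true class labels (placeholder name)
% scores:     N-by-K matrix of classification scores (placeholder name)
% classNames: the K class names, in the same order as the score columns
figure; hold on
for k = 1:numel(classNames)
    % XCrit="reca" and YCrit="prec" request recall and precision
    [recall, precision] = perfcurve(trueLabels, scores(:,k), classNames(k), ...
        XCrit="reca", YCrit="prec");
    plot(recall, precision)
end
hold off
xlabel("Recall"); ylabel("Precision"); legend(string(classNames))
```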
The documentation for the plot method of the rocmetrics object https://www.mathworks.com/help/stats/rocmetrics.plot.html has another Precision-Recall curve example: openExample('stats/PlotOtherPerformanceCurveExample')
The following applies to ROC curves only, not to Precision-Recall curves: for multiclass problems, the "plot" method of the rocmetrics object can also create ROC curves from averaged metrics via the "AverageROCType" name-value argument, and the "average" method of the rocmetrics object calculates these averaged metrics: https://www.mathworks.com/help/stats/rocmetrics.average.html . An example of an average ROC curve is here: openExample('stats/PlotAverageROCCurveExample'). The averaging options are micro, macro, and weighted macro.
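As a minimal sketch of the averaged-ROC option, assuming rocObj is an already-constructed rocmetrics object (as built later in this answer):

```matlab
% Sketch: macro-averaged ROC curve from an existing rocmetrics object.
% Other accepted values for AverageROCType are "micro" and "weighted".
plot(rocObj, AverageROCType="macro")
```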
In order to build the rocmetrics object with the rocmetrics function, or to use the perfcurve function, you will need the scores from the classify function.
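For the network in your question, that would look roughly like the following sketch. It assumes the "incep" network and "testDatastore" from your code, and it assumes that categories(testDatastore.Labels) matches the column order of the score matrix returned by classify; check that ordering against your network's output classes.

```matlab
% Sketch: get scores from classify and build a rocmetrics object.
% The second output of classify is the N-by-K score matrix.
[predicted, scores] = classify(incep, testDatastore);
trueLabels = testDatastore.Labels;
rocObj = rocmetrics(trueLabels, scores, categories(trueLabels), ...
    AdditionalMetrics="PositivePredictiveValue");
% One-vs-all Precision-Recall curves for all classes at once
plot(rocObj, XAxisMetric="TruePositiveRate", YAxisMetric="PositivePredictiveValue")
```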
Here is a precision-recall curve example for a tree model built with fisheriris data.
t = readtable("fisheriris.csv");
response = "Species";
% Create a set of models for 5-fold cross-validation
cvmdl = fitctree(t, response, KFold=5);
% Get cross-validation predictions and scores
[yfit, scores] = kfoldPredict(cvmdl);
% View confusion matrix
% The per-class Precision can be seen in the blue boxes in the
% column-summary along the bottom.
% The per-class Recall can be seen in the blue boxes in the row summary
% along the right side.
cm = confusionchart(t{:,response}, yfit);
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
% Calculate per-class precision, recall, and F1 from the raw confusion
% matrix counts.
% Precision = TP/(TP+FP); Recall = TP/(TP+FN);
% F1 is the harmonic mean of Precision and Recall.
% cm.Normalization is 'absolute' by default, so NormalizedValues holds
% the raw counts.
counts = cm.NormalizedValues;
precisionPerClass = diag(counts) ./ sum(counts,1)'; % column sums = TP+FP
recallPerClass = diag(counts) ./ sum(counts,2);     % row sums = TP+FN
F1PerClass = 2*diag(counts) ./ (sum(counts,2) + sum(counts,1)');
% Create the rocmetrics object.
% Add the metric "PositivePredictiveValue", which is Precision.
% The metric "TruePositiveRate", which is Recall, is included by default.
rocObj = rocmetrics(t{:,response}, scores, cvmdl.ClassNames, ...
    AdditionalMetrics="PositivePredictiveValue");
% For illustration, focus on metrics for one class, virginica
classindex=3;
% Plot the precision-recall curve for one class, given by classindex
% (By default, the rocmetrics plot function will plot the one-vs-all PR curves
% for all of the classes at once.)
r = plot(rocObj, YAxisMetric="PositivePredictiveValue", ...
    XAxisMetric="TruePositiveRate", ...
    ClassNames=rocObj.ClassNames(classindex));
hold on;
xlabel("Recall"); ylabel("Precision"); title('Precision-Recall Curve');
% Place the operating point on the figure
scatter(recallPerClass(classindex), precisionPerClass(classindex), [], r.Color, "filled");
% Update legend
legend(strcat(string(rocObj.ClassNames(classindex))," one-vs-all P-R curve"), ...
strcat(string(rocObj.ClassNames(classindex)), ...
sprintf(' one-vs-all operating point\nP=%4.1f%%, R=%4.1f%%, F1=%4.1f%%', ...
100*precisionPerClass(classindex),100*recallPerClass(classindex), ...
100*F1PerClass(classindex))));
hold off;