How to plot averaged ROC curve?

Ishfaque Ahmed on 20 Apr 2022
Commented: Adam Danz on 25 Apr 2022
I am trying to plot the ROC curve of my model over multiple iterations. The curves are not sampled at the same x locations, so I want to plot one averaged ROC curve from all 10 ROC curves. Please suggest a solution.
  1 comment
Chunru on 21 Apr 2022
You can interpolate each curve onto the same grid and then average the interpolated curves.


Accepted Answer

Chunru on 21 Apr 2022
% Create sample data
numPoints = 50;
nCurves = 10;
x = sort(rand(numPoints, nCurves));
y = (sort(rand(numPoints, nCurves))).^(1/4);
plot(x, y);
grid on;
hold on;
% Common grid
x0 = linspace(0, 1, 100);
% Interpolate every curve onto the common grid
yinterp = zeros(length(x0), nCurves);
for i = 1:nCurves
    yinterp(:, i) = interp1(x(:,i), y(:,i), x0, 'linear', 'extrap');
end
% Average across curves at each grid point
meany = mean(yinterp, 2);
% Plot the averaged curve (hold is already on)
plot(x0, meany, 'LineWidth', 2);
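One thing to watch with real ROC data: the false positive rate usually contains repeated values (the vertical steps of the curve), and interp1 needs distinct sample points. Below is a minimal sketch of one way to handle that before averaging. The fpr and tpr values are made up for illustration, and keeping the largest TPR at each repeated FPR is my assumption, not something from the original question. Other conventions are possible.

```matlab
% Hypothetical ROC-like curves with repeated FPR values (vertical steps).
fpr = {[0 0 0.2 0.2 0.6 1], [0 0.1 0.1 0.5 1 1]};
tpr = {[0 0.4 0.4 0.7 0.9 1], [0 0.3 0.6 0.8 0.8 1]};

% Common FPR grid shared by all curves.
x0 = linspace(0, 1, 101);
yinterp = zeros(numel(x0), numel(fpr));

for k = 1:numel(fpr)
    % Collapse repeated FPR values, keeping the largest TPR at each one
    % (an assumed convention: the upper envelope of each vertical step).
    [xu, ~, idx] = unique(fpr{k}(:));
    yu = accumarray(idx, tpr{k}(:), [], @max);
    % Every curve here spans [0, 1], so no extrapolation is needed.
    yinterp(:, k) = interp1(xu, yu, x0, 'linear');
end

% Average TPR across curves at each grid point.
meanTpr = mean(yinterp, 2);
```

After this, plot(x0, meanTpr, 'LineWidth', 2) draws the averaged curve in the same way as the code above.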
  1 comment
Ishfaque Ahmed on 21 Apr 2022
Thank you very much.


More Answers (1)

Image Analyst on 21 Apr 2022
Edited: Image Analyst on 21 Apr 2022
Try this:
% Create sample data because the original poster didn't upload theirs.
numPoints = 30;
x1 = sort(rand(1, numPoints));
x2 = sort(rand(1, numPoints));
x3 = sort(rand(1, numPoints));
x4 = sort(rand(1, numPoints));
x5 = sort(rand(1, numPoints));
x6 = sort(rand(1, numPoints));
x7 = sort(rand(1, numPoints));
x8 = sort(rand(1, numPoints));
x9 = sort(rand(1, numPoints));
x10 = sort(rand(1, numPoints));
y1 = sort(rand(1, numPoints));
y2 = sort(rand(1, numPoints));
y3 = sort(rand(1, numPoints));
y4 = sort(rand(1, numPoints));
y5 = sort(rand(1, numPoints));
y6 = sort(rand(1, numPoints));
y7 = sort(rand(1, numPoints));
y8 = sort(rand(1, numPoints));
y9 = sort(rand(1, numPoints));
y10 = sort(rand(1, numPoints));
plot(x1, y1, '-');
hold on;
plot(x2, y2, '-');
plot(x3, y3, '-');
plot(x4, y4, '-');
plot(x5, y5, '-');
plot(x6, y6, '-');
plot(x7, y7, '-');
plot(x8, y8, '-');
plot(x9, y9, '-');
plot(x10, y10, '-');
grid on;
hold on;
%========================================================================
% Since you have your own data you'd start here
% and NOT create the sample data above.
allx = sort([x1,x2,x3,x4,x5,x6,x7,x8,x9,x10], 'ascend');
% Then interpolate all the other curves so they're on a common x axis.
y1a = interp1(x1, y1, allx);
y2a = interp1(x2, y2, allx);
y3a = interp1(x3, y3, allx);
y4a = interp1(x4, y4, allx);
y5a = interp1(x5, y5, allx);
y6a = interp1(x6, y6, allx);
y7a = interp1(x7, y7, allx);
y8a = interp1(x8, y8, allx);
y9a = interp1(x9, y9, allx);
y10a = interp1(x10, y10, allx);
% Get all y together in one matrix.
allY = [y1a;y2a;y3a;y4a;y5a;y6a;y7a;y8a;y9a;y10a];
% Count how many curves have valid, non-NaN values at each x location.
counts = sum(~isnan(allY), 1);
% interp1 returns NaN outside the x range where each curve was originally defined.
% Set those NaNs to zero so they do not turn the column sums into NaN.
allY(isnan(allY)) = 0;
% We can't use mean(allY, 1) at this point because it would average in those zeros.
% Instead, sum allY down each column to get the sum of the valid values,
% then divide by counts (the number of valid curves at each x) to get the true mean.
meany = sum(allY, 1) ./ counts;
% Now plot the mean as a thick black curve.
hold on;
plot(allx, meany, 'k-', 'LineWidth', 4);
title('Thick black line is the mean of all curves')
Note how the mean gets a little wiggly near the ends. Fewer curves have valid (non-NaN) values there, so the mean moves closer to the curves that remain. For example, suppose that beyond x = 0.9 only 5 of the 10 curves have non-NaN values. There you want to average only those 5 curves, not all 10. In the picture above, close to x = 1 only the yellow curve has valid x values that far out, so the mean equals the yellow curve's y value there. This is why a plain mean(allY, 1) on the zero-filled matrix does not work and you have to divide the sum by the count, which changes with x. Does that make sense?
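For what it's worth, recent MATLAB releases also accept an 'omitnan' flag in mean, which should give the same result without zero-filling. Here is a small sketch comparing the two approaches on a made-up matrix (the values are hypothetical, with rows as curves and columns as x locations):

```matlab
% Hypothetical interpolated curves: rows are curves, columns are x locations.
% NaN marks x locations outside a curve's original range.
allY = [0.1 0.4 0.8 NaN;
        NaN 0.2 0.6 1.0;
        0.3 0.6 NaN NaN];

% Approach from the answer: zero-fill, sum, divide by the valid count.
counts = sum(~isnan(allY), 1);
filled = allY;
filled(isnan(filled)) = 0;
meanByCount = sum(filled, 1) ./ counts;

% Alternative: let mean skip the NaN entries itself.
meanOmitNan = mean(allY, 1, 'omitnan');
```

Both vectors should agree here. The count-based version has the advantage that counts is available afterwards, for example to see where the mean rests on only a few curves.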
  3 comments
Image Analyst on 21 Apr 2022
Well I guess you could compute the standard deviation at every x location and then get two curves
  1. the average curve plus the locally varying standard deviation
  2. the average curve minus the locally varying standard deviation.
Then plot those curves. One will be above the mean curve and one will be below it. Where you have only one curve (at the outside ends) the standard deviation will be zero there of course.
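A rough sketch of that idea, assuming the interpolated curves are stacked in a matrix with NaN outside each curve's range (before any zero-filling). The matrix values and the names sdy, upperBand and lowerBand are made up for illustration:

```matlab
% Hypothetical interpolated curves: rows are curves, columns are x locations.
allY = [0.10 0.40 0.80 NaN;
        0.20 0.50 0.70 0.95;
        0.30 0.60 NaN  1.00];

% Mean and standard deviation at each x location, ignoring NaN entries.
meany = mean(allY, 1, 'omitnan');
sdy = std(allY, 0, 1, 'omitnan');

% Curves one standard deviation above and below the mean.
upperBand = meany + sdy;
lowerBand = meany - sdy;
```

With a matching x vector, something like plot(allx, upperBand, '--') and plot(allx, lowerBand, '--') would then draw the two extra curves around the mean.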
Adam Danz on 25 Apr 2022
I wonder if curve fitting would be useful. Then you could get error estimates of the fit parameters and plot the smooth fit and the range of error.
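As a rough illustration of the fitting idea using only base MATLAB, here is a sketch that pools the points of several curves and fits one polynomial with polyfit, then uses the second output of polyval as an error estimate for predictions. The three curves and the cubic degree are arbitrary assumptions for the example. A real ROC curve may need a different model.

```matlab
% Hypothetical curves sampled on the same x values.
x = linspace(0, 1, 50);
y1 = x.^(1/2);
y2 = x.^(1/3);
y3 = x.^(1/4);

% Pool all points and fit a single cubic polynomial (degree is an assumption).
xAll = [x, x, x];
yAll = [y1, y2, y3];
[p, S] = polyfit(xAll, yAll, 3);

% Evaluate the fit on a grid; delta is polyval's error estimate for predictions.
x0 = linspace(0, 1, 100);
[yfit, delta] = polyval(p, x0, S);
```

Plotting x0 against yfit, yfit + delta and yfit - delta would then show the smooth fit with a band around it.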
