Using R^2 results to optimize number of neurons in hidden layer
I am trying to find the optimal number of neurons in a hidden layer, following Greg Heath's method of looping over candidate numbers of neurons, with several random-initialization trials per candidate. The resulting R-squared statistic is plotted below. I would like some help in using the R-squared results to pick the optimal number of neurons. Is this right:
- 23 neurons is a good choice, since all the trials exceed the desired threshold of R-squared > 0.995. To make a prediction, I could pick any of the 10 trial nets that were generated with 23 neurons (a selection sketch follows the code below).
- 15 neurons is a bad choice because the threshold is sometimes not met.
- More than 23 neurons is a bad choice because the network will be slower to run.
- More than 33 neurons is a very bad choice due to overfitting (the code below limits the number of neurons to avoid overfitting).
- To save time, I could stop my loop after, say, 28 neurons, since by then I would have had 5 instances (23-28 neurons) in which all 10 trials resulted in R-squared values above my threshold.
Here is the code, based on Greg's examples, with comments for learning.
%% Get data
[x, t] = engine_dataset; % Existing dataset
% Described here, https://www.mathworks.com/help/deeplearning/ug/choose-a-multilayer-neural-network-training-function.html#bss4gz0-26
[ I N ] = size(x); % Size of network inputs, nVar x nExamples
[ O N ] = size(t); % Size of network targets, nVar x nExamples
fractionTrain = 70/100; % How much of data will go to training (train + validation + test fractions must sum to 1)
fractionValid = 15/100; % How much of data will go to validation
fractionTest = 15/100; % How much of data will go to testing
Ntrn = N - round(fractionTest*N + fractionValid*N); % Number of training examples
Ntrneq = Ntrn*O; % Number of training equations
% MSE for a naïve constant output model that always outputs average of target data
MSE00 = mean(var(t',1));
%% Find a good value for the number of neurons H in the hidden layer
Hmin = 0; % Lowest number of neurons H to consider (H = 0 is a linear model)
% To avoid overfitting with too many neurons, require Ntrneq > Nw ==> H <= Hub (upper bound)
Hub = -1 + ceil( (Ntrneq-O) / (I+O+1) ); % Upper bound for Ntrneq >= Nw
Hmax = floor(Hub/10); % Stay well below the upper bound by dividing by 10
dH = 1;
% Number of random initial weight trials for each candidate h
Ntrials = 10;
% Randomize data division and initial weights
% Choose a repeatable initial random number seed
% so that you can reproduce any individual design
rng(0);
% Preallocate results: rows are trials, columns are candidate H values
R2 = zeros(Ntrials, numel(Hmin:dH:Hmax));
j = 0; % Loop counter over candidate numbers of hidden neurons
for h = Hmin:dH:Hmax % Number of neurons in hidden layer
    j = j+1;
    if h == 0 % Linear model
        net = fitnet([]); % Fits a linear model
        Nw = (I+1)*O; % Number of unknown weights
    else % Nonlinear neural network model
        net = fitnet(h); % Fits a network with h hidden nodes
        Nw = (I+1)*h + (h+1)*O; % Number of unknown weights
    end
    %% Divide data into training, validation, and testing sets
    % Data presented to the network during training; the network is adjusted according to its error
    net.divideParam.trainRatio = fractionTrain;
    % Data used to measure network generalization, and to halt training when generalization stops improving
    net.divideParam.valRatio = fractionValid;
    % Data used as an independent measure of network performance during and after training
    net.divideParam.testRatio = fractionTest;
    %% Loop through several trials for each candidate number of neurons
    % to increase the statistical reliability of the results
    for i = 1:Ntrials
        % Configure network inputs and outputs to best match the input and target data
        net = configure(net, x, t);
        % Train the network
        [net, tr, y, e] = train(net, x, t); % Error e is target - output
        % R-squared for normalized MSE: fraction of target variance modeled by the net
        R2(i, j) = 1 - mse(e)/MSE00;
    end
end
%% Plot R-squared
figure;
% Grid data in preparation for plotting
[nNeuronMat, trialMat] = meshgrid(Hmin:dH:Hmax, 1:Ntrials);
% Color of dot indicates the R-squared value
h1 = scatter(nNeuronMat(:), trialMat(:), 20, R2(:), 'filled');
hold on;
% Design is successful if it can account for 99.5% of mean target variance
iGood = R2 > 0.995;
% Circle the successful values
h2 = plot(nNeuronMat(iGood), trialMat(iGood), 'ko', 'markersize', 10);
xlabel('# Neurons'); ylabel('Trial #');
h3 = colorbar;
set(get(h3,'Title'),'String','R^2')
title({'Fraction of target variance that is modeled by net', ...
'Circled if it exceeds 99.5%'});
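And here is a minimal sketch of the selection step mentioned above, assuming the R2 matrix, Hmin, dH, Hmax, and the 0.995 threshold from the code: it picks the smallest H for which every trial clears the threshold.
% Candidate numbers of hidden neurons, matching the loop above
Hcand = Hmin:dH:Hmax;
% A candidate succeeds if all Ntrials trials exceed the threshold
allGood = all(R2 > 0.995, 1); % 1 x numel(Hcand) logical row
jBest = find(allGood, 1, 'first'); % Smallest successful candidate
if isempty(jBest)
    disp('No candidate H met the threshold on every trial.');
else
    fprintf('Smallest H with all trials above threshold: %d\n', Hcand(jBest));
end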
3 comments
Greg Heath
26 Sep 2018
KAE: You say you typically run through a reasonable
value, here H=34, before seeking to minimize H.
What do the 'reasonable value' results tell you?
Maybe a good choice of an R-squared threshold to use
later in the H minimization? So for your example we
might require the smallest H with Rsq >= 0.9975?
GREG: NO.
My training goal for all of the MATLAB examples has
always been the more realistic
NMSE = mse(t-y)/mean(var(t',1)) <= 0.01
or, equivalently,
Rsq = 1 - NMSE >= 0.99
That is not to say better values cannot be achieved.
Just that this is a more reasonable way to begin
before trying to optimize the final result.
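As a minimal sketch of that success criterion (assuming a trained network net and the data x and t from the question):
y = net(x); % Network output for all inputs
NMSE = mse(t - y) / mean(var(t',1)); % Normalized mean squared error
Rsq = 1 - NMSE; % Fraction of target variance modeled by the net
isSuccess = Rsq >= 0.99; % The standard training goal above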
% Using the command
>> help nndatasets
% yields
simplefit_dataset - Simple fitting dataset.
abalone_dataset - Abalone shell rings dataset.
bodyfat_dataset - Body fat percentage dataset.
building_dataset - Building energy dataset.
chemical_dataset - Chemical sensor dataset.
cho_dataset - Cholesterol dataset.
engine_dataset - Engine behavior dataset.
vinyl_dataset - Vinyl bromide dataset.
% then
>> SIZE1 = size(simplefit_dataset)
SIZE2 = size(abalone_dataset)
SIZE3 = size(bodyfat_dataset)
SIZE4 = size(building_dataset)
SIZE5 = size(chemical_dataset)
SIZE6 = size(cho_dataset)
SIZE7 = size(engine_dataset)
SIZE8 = size(vinyl_dataset)
% yields
SIZE1 = 1 94
SIZE2 = 8 4177
SIZE3 = 13 252
SIZE4 = 14 4208
SIZE5 = 8 498
SIZE6 = 21 264
SIZE7 = 2 1199
SIZE8 = 16 68308
% Therefore, FOR OBVIOUS REASONS, most of my MATLAB tutorial-type investigations are restricted to N < 500 (i.e., simplefit, bodyfat, chemical, and cho).
Hope this helps.
Greg
Accepted Answer
Greg Heath
23 Sep 2018
The answer is H = 12. H >= 13 is overfitting.
If that makes you queasy, average the output of the H = 12 net with the outputs of the two H = 13 nets.
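For instance, a minimal sketch of that averaging, assuming the three trained nets were kept (the variable names here are hypothetical):
% net12 is the H = 12 design; net13a and net13b are the two H = 13 designs
yAvg = ( net12(x) + net13a(x) + net13b(x) ) / 3; % Simple ensemble average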
NOTE: Overfitting is not "THE" problem. "THE PROBLEM" is:
OVERTRAINING AN OVERFIT NET.
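One standard guard against overtraining is the validation-based early stopping that fitnet already performs when a validation set is assigned; a minimal sketch (not a prescription, and max_fail is the patience, i.e., the number of consecutive validation failures allowed):
net = fitnet(12); % H = 12, per the answer above
net.divideParam.valRatio = 0.15; % A validation set enables early stopping
net.trainParam.max_fail = 6; % Halt after 6 consecutive validation failures (the default)
[net, tr] = train(net, x, t);
tr.stop % Reports why training halted, e.g. 'Validation stop.'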
Hope this helps.
*Thank you for formally accepting my answer*
Greg
3 comments
Greg Heath
26 Sep 2018
Edited: Greg Heath, 26 Sep 2018
There are TWO MAIN POINTS:
1. H <= Hub PREVENTS OVERFITTING, i.e. having more unknown weights than training equations.
2. min(H) subject to R2 > 0.99 STABILIZES THE SOLUTION by REDUCING the uncertainty in weight estimates.
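As a worked check of point 1 against the question's numbers (engine_dataset has I = 2, O = 2, N = 1199):
% Ntrn   = 1199 - round(0.15*1199 + 0.15*1199) = 1199 - 360 = 839
% Ntrneq = Ntrn*O = 839*2 = 1678
% Hub    = -1 + ceil( (1678-2)/(2+2+1) ) = -1 + 336 = 335
% Hmax   = floor(335/10) = 33, matching the 33-neuron limit in the question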
Hope this helps.
Greg
More Answers (0)