5-fold cross validation with neural networks (function approximation)
I have MATLAB code that implements hold-out cross validation (below). I am looking for help performing 5-fold cross validation with the same model architecture. Please help me figure this out. Thank you.
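In k-fold cross validation the samples are split into k folds of roughly equal size. Each fold is held out once for testing while the model is fitted on the remaining k-1 folds, so every sample is tested exactly once. A minimal sketch of just the fold assignment, assuming the Statistics and Machine Learning Toolbox function cvpartition is available (N = 103 is an arbitrary example size, not taken from the question):

```matlab
% Sketch: assign N samples to 5 folds (needs Statistics and Machine Learning Toolbox)
N = 103;                             % arbitrary example size
rng(0);                              % reproducible fold assignment
cv = cvpartition(N, 'KFold', 5);

timesTested = zeros(N, 1);           % how often each sample is held out
for k = 1:cv.NumTestSets
    isTest  = test(cv, k);           % logical N-by-1 mask for fold k
    isTrain = training(cv, k);       % complement: the other 4 folds
    timesTested = timesTested + isTest;
    fprintf('Fold %d: train on %d samples, test on %d\n', ...
        k, sum(isTrain), sum(isTest));
end
% every entry of timesTested is 1: each sample is tested exactly once
```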
%% Data
X = x';   % inputs, one sample per column (kept the same in every experiment)
Y = yte'; % targets, one sample per column
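%% Model parameters
% Choose a training function ('trainlm', 'trainscg', 'traingdx')
trainFcn = 'trainlm';
% Choose the number of neurons in the hidden layer
hiddenLayerSize = 17;
% Create the network first: fitnet returns a new network object, so any
% settings assigned to net before this call would be overwritten
net = fitnet(hiddenLayerSize,trainFcn);
%view(net)   % uncomment to view the network
% Choose activation functions ('logsig', 'tansig', 'purelin', 'poslin')
net.layers{1}.transferFcn = 'logsig';   % hidden layer
net.layers{2}.transferFcn = 'poslin';   % output layer
% Note: with 'mapstd' target scaling, a 'poslin' output cannot predict
% values below the mean of the targets; fitnet's default is 'purelin'
% Choose a performance function ('mae', 'mse')
net.performFcn = 'mse';
net.plotFcns = {'plotperform','plottrainstate','ploterrhist','plotregression','plotfit'};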
%% Data processing
net.input.processFcns = {'removeconstantrows','mapstd'};  % inputs: remove constant rows, then scale each row to zero mean and unit standard deviation
net.output.processFcns = {'removeconstantrows','mapstd'}; % targets: same processing (reversed automatically on the network outputs)
%% Data split (70% train, 15% validation, 15% test)
net.divideFcn = 'dividerand';          % random division
net.divideMode = 'sample';             % divide by sample (column)
net.divideParam.trainRatio = 70/100;   % training
net.divideParam.valRatio = 15/100;     % validation (early stopping)
net.divideParam.testRatio = 15/100;    % test
%% Train the neural network
[net,tr] = train(net,X,Y);
% net: trained network
% tr: training record (includes the train/validation/test indices)
%% network performance
figure(1), plotperform(tr) % Plot network performance
figure(2), plottrainstate(tr) % Plot training state values.
%% Errors, MSE and MAE
Ytest = net(X);               % network outputs for all samples in X
e = gsubtract(Y,Ytest);       % errors (Yactual - Ypred)
MSE = perform(net,Y,Ytest);   % performance using net.performFcn ('mse' here)
MAE = mae(net,Y,Ytest);       % mean absolute error
%% Regression performance
trOut = Ytest(tr.trainInd);   % training outputs (predicted)
trTarg = Y(tr.trainInd);      % training targets (actual)
vOut = Ytest(tr.valInd);      % validation outputs
vTarg = Y(tr.valInd);         % validation targets
tsOut = Ytest(tr.testInd);    % test outputs
tsTarg = Y(tr.testInd);       % test targets
figure(4), plotregression(trTarg, trOut, 'Train', vTarg, vOut, 'Validation', tsTarg, tsOut, 'Testing',Y,Ytest,'All')
% R^2: square of the correlation coefficient R returned by regression
R2_Train= regression(trTarg, trOut)^2;
R2_Val= regression(vTarg, vOut)^2;
R2_Test= regression(tsTarg, tsOut)^2;
R2_all= regression(Y,Ytest)^2;
%figure(3), ploterrhist(e) % Plot error histogram
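Applied to the script above, one way to get 5-fold cross validation is to let cvpartition assign each sample to one of five folds, then train a fresh copy of the same network once per fold and score it only on the fold it never saw. This is a minimal sketch, not a definitive implementation. It assumes the Statistics and Machine Learning Toolbox is available for cvpartition, keeps the question's settings (17 logsig hidden neurons, poslin output, trainlm, mapstd scaling), and replaces the 70/15/15 random split with an 85/15 train/validation split inside the four training folds, where the validation part is used only for early stopping. The stand-in data X and Y are hypothetical.

```matlab
% Sketch: 5-fold cross validation for the fitnet model in the question.
% Requires Deep Learning Toolbox and Statistics and Machine Learning
% Toolbox (cvpartition). Stand-in data below is hypothetical; replace it
% with X = x'; Y = yte'; from the question.
rng(1);                                   % reproducible folds, splits and weights
X = rand(3, 200);                         % 3 inputs x 200 samples (hypothetical)
Y = sum(X.^2, 1) + 0.05*randn(1, 200);    % 1 x 200 targets (hypothetical)

K  = 5;
N  = size(X, 2);                          % one sample per column
cv = cvpartition(N, 'KFold', K);          % every sample is tested exactly once

mseFold = zeros(K, 1);
maeFold = zeros(K, 1);
r2Fold  = zeros(K, 1);
Yhat    = zeros(size(Y));                 % out-of-fold predictions

for k = 1:K
    trainIdx = find(training(cv, k));     % 4 folds used for fitting
    testIdx  = find(test(cv, k));         % 1 fold held out

    % Build a fresh network per fold so no weights carry over
    net = fitnet(17, 'trainlm');
    net.layers{1}.transferFcn = 'logsig';
    net.layers{2}.transferFcn = 'poslin';
    net.performFcn = 'mse';
    net.input.processFcns  = {'removeconstantrows', 'mapstd'};
    net.output.processFcns = {'removeconstantrows', 'mapstd'};
    net.trainParam.showWindow = false;    % no training GUI

    % Split only the training folds: 85% fit, 15% early-stopping validation.
    % The held-out fold is never passed to train, so it does not
    % influence the mapstd normalization settings either.
    net.divideFcn = 'dividerand';
    net.divideParam.trainRatio = 0.85;
    net.divideParam.valRatio   = 0.15;
    net.divideParam.testRatio  = 0;

    net = train(net, X(:, trainIdx), Y(:, trainIdx));

    tTest = Y(:, testIdx);
    yTest = net(X(:, testIdx));           % predict the held-out fold
    Yhat(:, testIdx) = yTest;

    mseFold(k) = perform(net, tTest, yTest);   % uses net.performFcn ('mse')
    maeFold(k) = mae(net, tTest, yTest);
    r2Fold(k)  = regression(tTest, yTest)^2;
end

fprintf('MSE per fold: %s\n', mat2str(mseFold', 4));
fprintf('Mean MSE = %.4g (std %.4g), mean MAE = %.4g, mean R^2 = %.3f\n', ...
    mean(mseFold), std(mseFold), mean(maeFold), mean(r2Fold));
fprintf('R^2 of pooled out-of-fold predictions = %.3f\n', regression(Y, Yhat)^2);
```

The cross-validated error is the mean of the five fold errors, and the standard deviation across folds gives a rough idea of how stable that estimate is. Pooling the out-of-fold predictions in Yhat and computing a single R^2 is an alternative summary that uses every sample exactly once.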
Chetan Badgujar
on 5 Mar 2021
Accepted Answer