MNIST classification using batch method
Hi. I want to train a neural network on the MNIST database using the batch method. I use the code below, but my accuracy is very low, even though I think the code is correct. Can anyone help me, please?
function [hiddenWeights, outputWeights, error] = train_network_batch(numberOfHiddenUnits, input, target, epochs, batchSize, learningRate,lambda)
% The number of training vectors.
trainingSetSize = size(input, 2);
% Input vector has 784 dimensions.
inputDimensions = size(input, 1);
% We have to distinguish 10 digits.
outputDimensions = size(target, 1);
% Initialize the weights for the hidden layer and the output layer.
% hiddenWeights = randn(NHiddenUnit, inputDimensions)*1/sqrt(size(input, 1));
% outputWeights = randn(outputDimensions, NHiddenUnit)*1/sqrt(size(input, 1));
hiddenWeights = rand(numberOfHiddenUnits, inputDimensions);
outputWeights = rand(outputDimensions, numberOfHiddenUnits);
hiddenWeights = hiddenWeights./size(hiddenWeights, 2);
outputWeights = outputWeights./size(outputWeights, 2);
hiddenWeights_store = hiddenWeights;
outputWeights_store = outputWeights;
n = zeros(batchSize,1);
validation_count=0;
validation_accuracy=0;
figure; hold on;
%batch method
for t = 1: epochs
for k = 1: batchSize
% Select which input vector to train on.
n(k) = floor(rand(1)*trainingSetSize + 1);
% n(k) =k;
% Propagate the input vector through the network.
inputVector = input(:, n(k));
hiddenActualInput = hiddenWeights*inputVector;
hiddenOutputVector = linear_func(hiddenActualInput);
outputActualInput = outputWeights*hiddenOutputVector;
outputVector = linear_func(outputActualInput);
targetVector = target(:, n(k));
% Backpropagate the errors.
outputDelta = dlinear_func(outputActualInput).*(outputVector - targetVector);
hiddenDelta = dlinear_func(hiddenActualInput).*(outputWeights'*outputDelta);
% outputWeights_store = outputWeights_store -(learningRate*lambda/batchSize).*outputWeights- learningRate.*outputDelta*hiddenOutputVector';
hiddenWeights_store = hiddenWeights_store -(learningRate*lambda/batchSize).*hiddenWeights-learningRate.*hiddenDelta*inputVector';
% outputWeights =(1-(learningRate*lambda/batchSize)).*outputWeights - learningRate.*outputDelta*hiddenOutputVector';
% hiddenWeights = (1-(learningRate*lambda/batchSize)).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
end;
outputWeights=outputWeights+(outputWeights_store./batchSize);
hiddenWeights=hiddenWeights+(hiddenWeights_store./batchSize);
outputWeights_store=0;
hiddenWeights_store=0;
%*********************************end of batch method***************
% Calculate the error for plotting.
error = 0;
for k = 1: batchSize
inputVector = input(:, n(k));
targetVector = target(:, n(k));
error = error + norm(linear_func(outputWeights*linear_func(hiddenWeights*inputVector)) - targetVector, 2);
end;
error = error/batchSize;
plot(t, error,'*');
title(['MSE_ batch','NH= ',num2str(numberOfHiddenUnits),',',' alfa=',num2str(learningRate),' ,epoch=',num2str(epochs)]);
xlabel('epoch');
ylabel('cost');
inputValues=load('validation.mat');
inputValues=inputValues.v;
labels=load('label.mat');
labels=labels.l;
[correctlyClassified, classificationErrors]=validation_network(hiddenWeights,outputWeights,inputValues',labels);
correctlyClassified=correctlyClassified/10000;
if correctlyClassified<= validation_accuracy
validation_count=validation_count+1;
else
validation_count=0;
end
if validation_count>7
break;
end
validation_accuracy=correctlyClassified;
end;
end
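For comparison, here is a minimal sketch of a single mini-batch update, written as a standalone script. It is not the code from the question and makes several assumptions: the data are random stand-ins rather than MNIST, the activation is a sigmoid (the question's `linear_func` and `dlinear_func` are not shown), and the names `hiddenGrad`, `outputGrad`, `sigmoid` and `dsigmoid` are hypothetical. The points it illustrates are that the gradient accumulators start at zero for every batch, that each sample adds its gradient to them, and that the weights move once per batch against the averaged gradient.

```matlab
% Hypothetical sketch: one mini-batch update for a two-layer network.
% Random stand-ins for the MNIST arrays (assumption, not the real data).
numSamples = 200;
input  = rand(784, numSamples);              % one fake image per column
labels = randi(10, 1, numSamples);           % fake class indices 1..10
target = zeros(10, numSamples);
target(sub2ind(size(target), labels, 1:numSamples)) = 1;   % one-hot targets

numberOfHiddenUnits = 30;
batchSize = 20;
learningRate = 0.5;

% Assumed activation and its derivative (the question's versions are not shown).
sigmoid  = @(x) 1 ./ (1 + exp(-x));
dsigmoid = @(x) sigmoid(x) .* (1 - sigmoid(x));

% Zero-mean initial weights scaled by the number of inputs to each layer.
hiddenWeights = randn(numberOfHiddenUnits, 784) / sqrt(784);
outputWeights = randn(10, numberOfHiddenUnits) / sqrt(numberOfHiddenUnits);

% Pick the samples for this batch without repetition.
idx = randperm(numSamples, batchSize);

% Accumulators hold gradients only, so they start at zero for every batch.
hiddenGrad = zeros(size(hiddenWeights));
outputGrad = zeros(size(outputWeights));

for k = 1:batchSize
    % Forward pass for one sample.
    inputVector = input(:, idx(k));
    hiddenActualInput = hiddenWeights * inputVector;
    hiddenOutputVector = sigmoid(hiddenActualInput);
    outputActualInput = outputWeights * hiddenOutputVector;
    outputVector = sigmoid(outputActualInput);

    % Backward pass for the squared-error cost.
    outputDelta = dsigmoid(outputActualInput) .* (outputVector - target(:, idx(k)));
    hiddenDelta = dsigmoid(hiddenActualInput) .* (outputWeights' * outputDelta);

    % Add this sample's gradient; the weights stay fixed inside the batch.
    outputGrad = outputGrad + outputDelta * hiddenOutputVector';
    hiddenGrad = hiddenGrad + hiddenDelta * inputVector';
end

% One update per batch: step against the averaged gradient.
outputWeights = outputWeights - learningRate * outputGrad / batchSize;
hiddenWeights = hiddenWeights - learningRate * hiddenGrad / batchSize;
```

In a full training loop this block would be repeated for every batch in every epoch. A weight-decay term could be added in the final two lines if regularisation with `lambda` is wanted.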