Training a shallow neural network (no hidden layer) for MNIST classification. Getting low accuracy
Achint Kumar
on 25 Nov 2019
Answered: Srivardhan Gadila
on 13 Jan 2020
I am trying to implement classification of the MNIST dataset without any hidden layers, so the input is a 784x1 vector and the output is a 10x1 vector (after one-hot encoding).
The problem is that once I train the network, I get very low accuracy (~1%), and the reason is unclear to me. My guess is that my update rules are incorrect. After training, the output vector on the test data (a_test below) is no longer one-hot but has multiple 1's. I cannot figure out where I am going wrong.
alp = 0.0001;    % learning rate
epoch = 50;      % number of iterations
dim_train = 784; % dimension of input data vector
for itr = 1:epoch
    % image_train is a 784x60000 data matrix, W is a 10x784 weight matrix, b is the bias
    z = W*image_train + repmat(b', n_train, 1)';
    a = 1./(1 + exp(-z));   % sigmoid activation
    a = a + 0.001;          % to avoid zero inside log when calculating cross-entropy loss
    a_flg = sum(a);
    for i = 1:n_train
        a(:,i) = a(:,i)/a_flg(i); % normalizing output so each column sums to 1
    end
    L = -sum(Y_train.*log(a), 'all');                 % cross-entropy loss
    dLdW = 1/dim_train*(a - Y_train)*image_train';    % calculating dL/dW
    dLdb = 1/dim_train*(a - Y_train)*ones(n_train,1); % calculating dL/db
    W = W - alp*dLdW; % updating weights (gradient descent)
    b = b - alp*dLdb; % updating bias (gradient descent)
    loss(itr) = 1/n_train*L;
end
%% Testing
a_test = 1./(1+exp(-(W*image_test+repmat(b',n_test,1)')));
for i = 1:n_test
a_test(~(a_test(:,i)==max(a_test(:,i))),i)=0;
end
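As a side note, predicted class labels and test accuracy can be read off directly with max instead of zeroing the non-maximal entries; a minimal sketch, assuming Y_test holds the one-hot test labels (a variable name not in the original code):

```matlab
[~, pred]  = max(a_test);        % index of the largest activation per test column
[~, truth] = max(Y_test);        % true class index recovered from the one-hot labels
accuracy   = mean(pred == truth) % fraction of correctly classified test images
```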
Accepted Answer
Srivardhan Gadila
on 13 Jan 2020
A network architecture defined without any hidden layer may not be able to learn to classify the digits belonging to 10 classes. It is recommended to use a few hidden layers.
You can also refer to Define Custom Deep Learning Layers & Create Simple Deep Learning Network for Classification
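Following this recommendation, a one-hidden-layer version of the questioner's training loop could look like the sketch below. It assumes the same image_train, Y_train and n_train variables from the question; the hidden size h, the ReLU/softmax choice, and the scaled random initialization are illustrative, not part of the original code. Implicit expansion (R2016b+) replaces the repmat calls:

```matlab
h = 100;  alp = 0.1;  epoch = 50;
W1 = randn(h, 784)*sqrt(2/784);  b1 = zeros(h, 1);   % input  -> hidden
W2 = randn(10, h)*sqrt(2/h);     b2 = zeros(10, 1);  % hidden -> output
for itr = 1:epoch
    z1 = W1*image_train + b1;     % hidden pre-activation
    a1 = max(z1, 0);              % ReLU activation
    z2 = W2*a1 + b2;              % output pre-activation
    e  = exp(z2 - max(z2));       % softmax, stabilized per column
    a2 = e ./ sum(e);
    d2 = (a2 - Y_train)/n_train;  % gradient of mean cross-entropy w.r.t. z2
    d1 = (W2'*d2) .* (z1 > 0);    % backpropagate through the ReLU
    W2 = W2 - alp*d2*a1';            b2 = b2 - alp*sum(d2, 2);
    W1 = W1 - alp*d1*image_train';   b1 = b1 - alp*sum(d1, 2);
end
```

Note that the softmax here also removes the need for the manual column normalization and the a + 0.001 offset used in the question, since its outputs already sum to 1 and stay strictly positive.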
More Answers (0)