Convolutional Neural Network for traffic sign classification.
Hello,
I want to create a traffic sign classifier with a CNN based on the GTSRB dataset. I built my network for only 11 classes. When I classify a picture of a sign that is in my dataset, I get almost 100% accuracy every time, but when I try to classify a sign that isn't in my dataset (for example, a 20 km/h speed limit image downloaded from Google), the prediction is mostly incorrect.
How can I improve my network?
Dataset: https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/published-archive.html
Code:
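clc; clear;
% Load the GTSRB images; class labels are taken from the subfolder names
DatasetPath = fullfile('C:\Users\Pulpit\Splotowe Sieci Neuronowe\GTSRB\fl');
Data = imageDatastore(DatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% Use 70% of each class for training and 30% for validation
[imdsTrain,imdsValidation] = splitEachLabel(Data,0.7,'randomize');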
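% Random augmentation, applied to the training images only
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...      % rotate by -20 to 20 degrees
    'RandXTranslation',[-1 1], ...    % shift by at most 1 pixel
    'RandYTranslation',[-1 1], ...
    'RandXReflection', false, ...
    'RandYReflection', true ...       % random top-bottom (upside-down) flips
    );
imageSize = [32 32 3];
% Resize all images to 32x32 to match the input layer
augimdsTrain = augmentedImageDatastore(imageSize,imdsTrain,'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(imageSize(1:2),imdsValidation);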
%%
layers = [
    imageInputLayer([32 32 3])
    % Block 1: two 3x3 convolutions with 64 filters, then BN + ReLU
    convolution2dLayer(3,64,'Stride',1,'Padding','same')
    convolution2dLayer(3,64,'Stride',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(1,'Stride',1)    % pool size 1, stride 1: no downsampling
    % Block 2: two 3x3 convolutions with 64 filters
    convolution2dLayer(3,64,'Stride',1,'Padding','same')
    convolution2dLayer(3,64,'Stride',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(1,'Stride',1)
    % Block 3: three 3x3 convolutions with 128 filters
    convolution2dLayer(3,128,'Stride',1,'Padding','same')
    convolution2dLayer(3,128,'Stride',1,'Padding','same')
    convolution2dLayer(3,128,'Stride',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(1,'Stride',1)
    % Block 4: three 3x3 convolutions with 256 filters
    convolution2dLayer(3,256,'Stride',1,'Padding','same')
    convolution2dLayer(3,256,'Stride',1,'Padding','same')
    convolution2dLayer(3,256,'Stride',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(1,'Stride',1)
    % Classifier: fully connected layers (no activation between them)
    fullyConnectedLayer(64)
    fullyConnectedLayer(32)
    fullyConnectedLayer(11)            % one output per class
    softmaxLayer
    classificationLayer];
%miniBatchSize = 32;
%valFrequency = floor(numel(imdsValidation.Files)/miniBatchSize);
%%
%{
options = trainingOptions('sgdm', ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropFactor',0.2, ...
'LearnRateDropPeriod',5, ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',valFrequency, ...
'InitialLearnRate',0.001, ...
'MaxEpochs',10, ...
'MiniBatchSize',miniBatchSize, ...
'Plots','training-progress');
%}
% SGD with momentum: mini-batches of 10 images, 6 epochs
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...        % validate every 3 iterations
    'Verbose',false, ...
    'Plots','training-progress');
%%
convnet = trainNetwork(augimdsTrain,layers,options);
%%
% Evaluate on the validation set
YTest = imdsValidation.Labels;
[YPred,prob] = classify(convnet,augimdsValidation);
accuracy = sum(YPred == YTest)/numel(YTest)
plotconfusion(YTest,YPred)
% Show 4 random validation images with the predicted label and its score
idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(prob(idx(i),:)),3) + "%");
end
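One likely issue in the layer array above: maxPooling2dLayer(1,'Stride',1) leaves the feature maps at 32x32, so no block actually downsamples, and the three fully connected layers have no activation between them, so together they behave like a single linear layer. The following is a minimal sketch of a variant, assuming the same 11 classes and 32x32x3 input; the pool sizes, the 128-unit hidden layer, the 0.5 dropout rate and the helper name convBlock are illustrative choices, not tuned or tested values.

```matlab
% Sketch of a variant layer array (illustrative, not tuned on GTSRB).
% convBlock is a hypothetical helper: 3x3 conv -> batch norm -> ReLU.
numClasses = 11;
convBlock = @(numFilters) [convolution2dLayer(3, numFilters, 'Padding', 'same'); batchNormalizationLayer; reluLayer];

layers = [
    imageInputLayer([32 32 3])
    convBlock(64)
    convBlock(64)
    maxPooling2dLayer(2, 'Stride', 2)   % 32x32 -> 16x16
    convBlock(128)
    convBlock(128)
    maxPooling2dLayer(2, 'Stride', 2)   % 16x16 -> 8x8
    convBlock(256)
    convBlock(256)
    maxPooling2dLayer(2, 'Stride', 2)   % 8x8 -> 4x4
    fullyConnectedLayer(128)
    reluLayer                           % nonlinearity between the FC layers
    dropoutLayer(0.5)                   % regularization against overfitting
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
```

This array can be passed to trainNetwork(augimdsTrain,layers,options) in place of the original one. Setting 'RandYReflection' to false may also be worth trying, since upside-down signs do not occur on real roads.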
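An image downloaded from the web should go through the same preprocessing as the training data (32x32, 3 channels) before it is classified. GTSRB images are close crops around a single sign, so a web photo is likely to resemble the training data more if it is cropped tightly around the sign first. A minimal sketch, assuming a hypothetical file speed_limit_20.jpg that is already cropped, and the convnet trained by the script above:

```matlab
% Sketch: classify one image that is not part of GTSRB.
% Assumptions: 'speed_limit_20.jpg' is a hypothetical file, already cropped
% around the sign; convnet is the network trained by the script above.
[I, map] = imread('speed_limit_20.jpg');
if ~isempty(map)
    I = uint8(255 * ind2rgb(I, map));  % indexed image (e.g. GIF) -> RGB
end
if size(I, 3) == 1
    I = repmat(I, [1 1 3]);            % grayscale -> 3 channels
end
I = imresize(I, [32 32]);              % same size as imageInputLayer([32 32 3])
[label, scores] = classify(convnet, I);
fprintf('%s, %.1f%%\n', string(label), 100 * max(scores));
```

Since the network was trained on only 11 classes, it is also worth checking that the tested sign is one of them (for example with convnet.Layers(end).Classes); a classifier can only output classes it was trained on.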
2 Comments
Omran Adnanoglu
on 13 Apr 2020
Hi Krystian,
I am working on the same project. I am still at the very beginning, and I get the following error when reading the data with the provided code:
Error using imread (line 438)
End of file reached too early.
Error in TrainTrafficSigns (line 28)
Img = imread(ImgFile);
Could you please share the part of your code where you read the data?
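The message "End of file reached too early" suggests that imread reached the end of a file before reading all the image data it expected, which can happen with a truncated or corrupted file (for example after an interrupted download or extraction). A minimal sketch for listing the files that fail to load, assuming rootDir is a placeholder for the folder where the GTSRB archive was extracted:

```matlab
% Sketch: find .ppm files that imread cannot decode.
% Assumption: rootDir is a placeholder; point it at your GTSRB folder.
rootDir = 'GTSRB';
files = dir(fullfile(rootDir, '**', '*.ppm'));   % recursive search (R2016b+)
bad = {};
for k = 1:numel(files)
    f = fullfile(files(k).folder, files(k).name);
    try
        imread(f);                               % attempt a full decode
    catch err
        bad{end+1} = f; %#ok<AGROW>
        fprintf('%s: %s\n', f, err.message);
    end
end
fprintf('%d of %d files could not be read.\n', numel(bad), numel(files));
```

Any file reported here can be re-extracted or downloaded again before training.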
Krystian P
on 19 Apr 2020
Edited: Krystian P on 19 Apr 2020
Answers (1)
ZHI-RUI LIN
on 7 Jun 2022
Excuse me, I would like to ask how you loaded the GTSRB dataset (.ppm files) into MATLAB. Thank you!
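imread supports the PPM format, and imageDatastore by default includes every extension listed by imformats, which is why the script in the question can load the GTSRB folders directly. A minimal sketch that makes the extension explicit, assuming datasetPath is a placeholder for a folder containing one subfolder per class with .ppm images inside:

```matlab
% Sketch: load GTSRB training images (.ppm) into an imageDatastore.
% Assumption: datasetPath is a placeholder; each class is a subfolder
% (e.g. 00000, 00001, ...) containing that class's .ppm files.
datasetPath = fullfile('GTSRB', 'Final_Training', 'Images');
imds = imageDatastore(datasetPath, ...
    'IncludeSubfolders', true, ...
    'FileExtensions', '.ppm', ...       % restrict to the GTSRB image format
    'LabelSource', 'foldernames');      % class label = subfolder name
disp(countEachLabel(imds))              % number of images per class
I = readimage(imds, 1);                 % read one image (via imread)
disp(size(I))
```

The resulting datastore can then be split with splitEachLabel and wrapped in augmentedImageDatastore, as in the question.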