
Training options for an MLP created with Deep Learning Toolbox

Alberto Tellaeche on 21 Nov 2022
Hi all,
I am trying to train a standard MLP created with Deep Learning Toolbox to classify the digits in the MNIST dataset.
Since I do not want to use a CNN for this example, I have flattened the image data, turning each 28x28 image into an input vector of 784 elements.
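The flattening described above relies on MATLAB's column-major reshape. This small sketch uses random stand-in data (not MNIST), assuming the images are stored as a 28-by-28-by-1-by-N array. It shows that reshaping to 784-by-N puts each image in its own column, read column by column. It needs only base MATLAB.

```matlab
% Random stand-in for a 28x28x1xN image array (N = 5 here), not real MNIST data
numImages = 5;
X = rand(28, 28, 1, numImages);

% reshape works in column-major order, so each 28x28 image becomes one
% 784-element column: column k holds image k, read column by column
Xflat = reshape(X, 28*28, numImages);

% Column 3 equals image 3 flattened on its own
isequal(Xflat(:, 3), reshape(X(:, :, 1, 3), [], 1))   % returns logical true
```

Transposing Xflat (Xflat.') gives the other common layout, with one image per row.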
My code is as follows:
```matlab
clear; clc;

% MNIST data files
filenameImagesTrain = 'train-images-idx3-ubyte.gz';
filenameLabelsTrain = 'train-labels-idx1-ubyte.gz';
filenameImagesTest = 't10k-images-idx3-ubyte.gz';
filenameLabelsTest = 't10k-labels-idx1-ubyte.gz';

% Load images and labels with the MNIST helper functions (not defined in this listing)
XTrain = processImagesMNIST(filenameImagesTrain);
YTrain = processLabelsMNIST(filenameLabelsTrain);
XTest = processImagesMNIST(filenameImagesTest);
YTest = processLabelsMNIST(filenameLabelsTest);

sizeX = 28;
sizeY = 28;

% Flatten each 28x28 training image into one 784-element column (784 x 60000)
XVectorTrain = reshape(XTrain, 28*28, 60000);

% MLP: 784 inputs -> 32 hidden units (ReLU) -> 10 output classes
multilayer_perceptron = [
    sequenceInputLayer(sizeX*sizeY,"Name","input")
    fullyConnectedLayer(32,"Name","capa 1")
    reluLayer("Name","relu")
    fullyConnectedLayer(10,"Name","capa 2")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
plot(layerGraph(multilayer_perceptron));

options = trainingOptions("sgdm","Plots","training-progress", ...
    "SequenceLength",sizeY*sizeX,...
    "MaxEpochs",40,"MiniBatchSize",1, ...
    "InitialLearnRate", 0.005,"Momentum",0.9, ...
    "ExecutionEnvironment","auto");
% Alternative settings (commented out):
%"MaxEpochs",40,"MiniBatchSize",8, ...
%"Shuffle","every-epoch", ...

network = trainNetwork(XVectorTrain,categorical(YTrain),multilayer_perceptron,options);
```
However, I cannot get the network to train (the training-progress screenshots from the original post were not preserved), and the result stays the same until the end of training.
I know this can be done, since I have seen similar examples in Keras, but I cannot make it work in MATLAB.
I would be very grateful if someone could help me with this issue,
Best regards,

Answers (1)

Sai Kiran on 22 Dec 2022
Hi,
The function trainNetwork returns the trained model. In your code it is stored in the variable network, which can then be used to predict results.
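A minimal sketch of that prediction step follows. It assumes the variables from the question (network, XTest, YTest) are in the workspace and that the test images are flattened the same way as the training images. For a network that starts with a sequenceInputLayer, classify treats a single numeric matrix as one sequence, so the predictions should come back as a categorical sequence with one label per column. That behavior is worth checking in your release. The dlarray check is a precaution in case the helper function returns a dlarray.

```matlab
% Predict the MNIST test digits with the network returned by trainNetwork
% (variable "network" in the question's code).
% Assumes XTest and YTest were loaded as in the question.
if isa(XTest, "dlarray")
    XTest = extractdata(XTest);   % use a plain numeric array
end

% Flatten the test images exactly like the training images:
% 784-by-10000, one image per column
XVectorTest = reshape(XTest, 28*28, []);

% Predicted label for each test image
YPred = classify(network, XVectorTest);

% Fraction of test images classified correctly
accuracy = mean(YPred(:) == categorical(YTest(:)))
```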
Please refer to the following documentation to learn how to use the trained model in later steps.
I hope this resolves your query.
Regards,
Sai Kiran Ratna
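The code in the question passes the 784-by-60000 matrix to a sequenceInputLayer, so trainNetwork handles it as sequence data, which is why it uses the "SequenceLength" option. For independent feature vectors such as flattened images, Deep Learning Toolbox also provides featureInputLayer (introduced in R2020b, so available in R2022a). This layer expects one observation per row.

The following sketch shows that alternative. It is not a confirmed diagnosis of the plot in the question, and it relies on these assumptions:
- The same processImagesMNIST/processLabelsMNIST helpers and MNIST files as the question are available.
- The hyperparameters (10 epochs, mini-batch size 128, learning rate 0.01, shuffling every epoch) are illustrative, untuned choices.
- The layer names "fc1"/"fc2" are arbitrary.

```matlab
% The same MLP with featureInputLayer instead of sequenceInputLayer.
% Assumes the MNIST helper functions and files from the question are available.
XTrain = processImagesMNIST('train-images-idx3-ubyte.gz');
YTrain = processLabelsMNIST('train-labels-idx1-ubyte.gz');

% Work with a plain single-precision array in case the helper returns a dlarray
if isa(XTrain, "dlarray")
    XTrain = extractdata(XTrain);
end
XTrain = single(XTrain);

% Scale pixels to [0, 1] if they are still raw 0-255 values
if max(XTrain(:)) > 1
    XTrain = XTrain / 255;
end

numFeatures = 28*28;
% featureInputLayer expects one observation per row: 60000-by-784
XFeatTrain = reshape(XTrain, numFeatures, []).';

layers = [
    featureInputLayer(numFeatures, "Name", "input")
    fullyConnectedLayer(32, "Name", "fc1")
    reluLayer("Name", "relu")
    fullyConnectedLayer(10, "Name", "fc2")
    softmaxLayer("Name", "softmax")
    classificationLayer("Name", "classoutput")];

% Illustrative (untuned) hyperparameters; no "SequenceLength" needed here
options = trainingOptions("sgdm", ...
    "MaxEpochs", 10, ...
    "MiniBatchSize", 128, ...
    "InitialLearnRate", 0.01, ...
    "Momentum", 0.9, ...
    "Shuffle", "every-epoch", ...
    "Plots", "training-progress");

% One label per row of XFeatTrain
net = trainNetwork(XFeatTrain, categorical(YTrain(:)), layers, options);
```

The transpose is the main difference in data layout. The question's matrix has one image per column, while featureInputLayer reads one image per row, matching the one label per row in categorical(YTrain(:)).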

Categories: Deep Learning Toolbox

Version: R2022a
