La traduction de cette page n'est pas à jour. Cliquez ici pour voir la dernière version en anglais.
Convertir un réseau de classification en réseau de régression
Cet exemple montre comment convertir un réseau de classification entraîné en réseau de régression.
Des réseaux de classification d’images préentraînés ont été entraînés sur plus d’un million d'images et peuvent classer des images dans 1 000 catégories d’objets, par exemple un clavier, une tasse à café, un crayon et de nombreux animaux. Les réseaux ont appris des représentations avec de nombreuses caractéristiques pour une grande variété d’images. Le réseau utilise une image comme entrée puis produit une étiquette pour l’objet dans l’image avec les probabilités pour chaque catégorie d’objet.
L’apprentissage par transfert est communément utilisé dans les applications de Deep Learning. Vous pouvez utiliser un réseau préentraîné comme point de départ pour apprendre une nouvelle tâche. Cet exemple montre comment utiliser un réseau de classification préentraîné et le réentraîner pour des tâches de régression.
L'exemple charge l’architecture d’un réseau de neurones à convolution préentraîné pour la classification, remplace les couches pour la classification et réentraîne le réseau pour prédire les angles des chiffres écrits à la main et inclinés.
Charger un réseau préentraîné
Chargez le réseau préentraîné à partir du fichier d'aide digitsClassificationConvolutionNet.mat
. Ce fichier contient un réseau de classification qui classifie les chiffres écrits à la main.
load digitsClassificationConvolutionNet
layers = net.Layers
layers = 13x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images 2 'conv_1' 2-D Convolution 10 3x3x1 convolutions with stride [2 2] and padding [0 0 0 0] 3 'batchnorm_1' Batch Normalization Batch normalization with 10 channels 4 'relu_1' ReLU ReLU 5 'conv_2' 2-D Convolution 20 3x3x10 convolutions with stride [2 2] and padding [0 0 0 0] 6 'batchnorm_2' Batch Normalization Batch normalization with 20 channels 7 'relu_2' ReLU ReLU 8 'conv_3' 2-D Convolution 40 3x3x20 convolutions with stride [2 2] and padding [0 0 0 0] 9 'batchnorm_3' Batch Normalization Batch normalization with 40 channels 10 'relu_3' ReLU ReLU 11 'gap' 2-D Global Average Pooling 2-D global average pooling 12 'fc' Fully Connected 10 fully connected layer 13 'softmax' Softmax softmax
charger les données
L'ensemble de données contient des images synthétiques de chiffres écrits à la main, ainsi que les angles de rotation (en degrés) correspondants appliqués à chaque image.
Chargez les images d’apprentissage et de test sous forme de tableaux 4-D à partir des fichiers de support DigitsDataTrain.mat
et DigitsDataTest.mat
. Les variables anglesTrain
et anglesTest
correspondent aux angles de rotation en degrés. Les jeux de données d’apprentissage et de test contiennent 5 000 images chacun.
load DigitsDataTrain load DigitsDataTest
Affichez 20 images d’apprentissage aléatoires avec imshow
.
numTrainImages = numel(anglesTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
Remplacer les couches finales
Les couches de convolution du réseau extraient les caractéristiques de l’image que la dernière couche entraînable a utilisée, pour classer l’image en entrée. La couche 'fc'
contient des informations sur la manière de combiner les caractéristiques extraites par le réseau en probabilités de classe. Pour réentraîner un réseau préentraîné pour la régression, remplacez cette couche et la couche softmax suivante par une nouvelle couche adaptée à la tâche.
Remplacez la dernière couche entièrement connectée par une couche entièrement connectée de taille 1 (nombre de réponses).
numResponses = 1; layer = fullyConnectedLayer(numResponses,Name="fc"); net = replaceLayer(net,"fc",layer)
net = dlnetwork with properties: Layers: [13x1 nnet.cnn.layer.Layer] Connections: [12x2 table] Learnables: [14x3 table] State: [6x3 table] InputNames: {'imageinput'} OutputNames: {'softmax'} Initialized: 0 View summary with summary.
Supprimez la couche softmax.
net = removeLayers(net,"softmax");
Ajuster les facteurs de taux d’apprentissage de couches
Le réseau est maintenant prêt à être réentraîné sur les nouvelles données. Si vous le souhaitez, vous pouvez ralentir l’apprentissage des poids des couches précédentes du réseau en augmentant le taux d’apprentissage de la nouvelle couche entièrement connectée et en réduisant le taux d’apprentissage global au moment où vous spécifiez les options d’apprentissage.
Augmentez les taux d’apprentissage des paramètres de couches entièrement connectées en leur appliquant un facteur défini avec la fonction setLearnRateFactor
.
net = setLearnRateFactor(net,"fc","Weights",10); net = setLearnRateFactor(net,"fc","Bias",10);
Spécifier les options d’apprentissage
Spécifiez les options d’apprentissage. Le choix des options nécessite une analyse empirique. Pour explorer différentes configurations dans les options d’apprentissage au cours de vos expérimentations, vous pouvez utiliser l’application Experiment Manager.
Spécifiez un taux d’apprentissage réduit de 0,0001.
Affichez la progression de l’apprentissage dans un tracé.
Désactivez la sortie en clair.
options = trainingOptions("sgdm",... InitialLearnRate=0.001, ... Plots="training-progress",... Verbose=false);
Entraîner le réseau de neurones
Entraînez le réseau de neurones avec la fonction trainnet
. Pour la régression, utilisez la perte d’erreur quadratique moyenne. Par défaut et selon disponibilité, la fonction trainnet
utilise un GPU. L’utilisation d’un GPU nécessite une licence Parallel Computing Toolbox™ et un dispositif GPU supporté. Pour plus d'information sur les dispositifs supportés, veuillez consulter Exigences de calcul du GPU (Parallel Computing Toolbox). Sinon, la fonction utilise le CPU. Pour spécifier l’environnement d’exécution, utilisez l’option d’apprentissage ExecutionEnvironment
.
net = trainnet(XTrain,anglesTrain,net,"mse",options);
Tester le réseau
Testez la performance du réseau en évaluant la précision sur les données de test.
Utilisez predict
pour prédire les angles de rotation appliqués aux images de validation.
YTest = predict(net,XTest);
Visualisez les prédictions dans un diagramme de dispersion. Tracez les valeurs prédites par rapport aux valeurs vraies.
figure scatter(YTest,anglesTest,"+") xlabel("Predicted Value") ylabel("True Value") hold on plot([-60 60], [-60 60],"r--")
Voir aussi
trainnet
| trainingOptions
| dlnetwork