Main Content

En savoir plus sur l’apprentissage par transfert

Cet exemple montre comment utiliser Deep Network Designer afin de préparer un réseau pour l’apprentissage par transfert.

L’apprentissage par transfert est communément utilisé dans les applications de Deep Learning. Vous pouvez utiliser un réseau préentraîné comme point de départ pour apprendre une nouvelle tâche. L’ajustement précis d’un réseau avec l’apprentissage par transfert est généralement beaucoup plus rapide et plus facile que l’apprentissage d’un réseau avec des pondérations initialisées de manière entièrement aléatoire. Vous pouvez transférer rapidement les caractéristiques apprises vers une nouvelle tâche avec un nombre réduit d’images d’apprentissage.

Charger les images

Dans l’espace de travail, extrayez le jeu de données MathWorks® Merch. Pour accéder à ces données, ouvrez l’exemple en tant que live script. Ce petit jeu de données contient 75 images de marchandises MathWorks appartenant à cinq classes différentes (casquette, cube, cartes à jouer, tournevis et lampe de poche).

folderName = "MerchData";
unzip("MerchData.zip",folderName);

Créez un datastore d’images. Un datastore d’images vous permet de stocker de grandes collections d’images, notamment des données qui ne peuvent pas être stockées en mémoire, et de lire efficacement ces images en batch pendant l’apprentissage d’un réseau de neurones. Spécifiez le dossier contenant les images extraites et indiquez que les noms des sous-dossiers correspondent aux étiquettes d’images.

imds = imageDatastore(folderName, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

Affichez quelques exemples d’images.

numImages = numel(imds.Labels);
idx = randperm(numImages,16);
I = imtile(imds,Frames=idx);
figure
imshow(I)

Récupérez les noms des classes et le nombre de classes.

classNames = categories(imds.Labels);
numClasses = numel(classNames);

Partitionnez les données entre des jeux d’apprentissage, de validation et de test. Utilisez 70 % des images pour l’apprentissage, 15 % pour la validation et 15 % pour les tests. La fonction splitEachLabel divise le datastore d’images en trois nouveaux datastores.

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");

Charger un réseau préentraîné

Pour adapter un réseau de neurones préentraîné à une nouvelle tâche, utilisez l’application Deep Network Designer.

deepNetworkDesigner

Sélectionnez SqueezeNet depuis la liste des réseaux préentraînés et cliquez sur Open.

Deep Network Designer affiche une vue dézoomée de l’ensemble du réseau.

Éditer le réseau pour l’apprentissage par transfert

Pour réentraîner SqueezeNet à classer de nouvelles images, modifiez la dernière couche de convolution 2D du réseau, conv10.

Dans le volet Designer, sélectionnez la couche conv10. Au bas du volet Properties, cliquez sur Unlock Layer. Dans la boîte de dialogue d’avertissement qui apparaît, cliquez sur Unlock Anyway. Les propriétés de la couche sont alors déverrouillées pour que vous puissiez les adapter à votre nouvelle tâche.

Définissez la propriété NumFilters au nouveau nombre de classes, 5 dans cet exemple. Modifiez les taux d’apprentissage de sorte que l’apprentissage soit plus rapide dans la nouvelle couche que dans les couches transférées en définissant WeightLearnRateFactor et BiasLearnRateFactor à 10.

Pour vérifier que le réseau est prêt pour l’apprentissage, cliquez sur Analyze. Deep Learning Network Analyzer n’indique aucune erreur ni aucun avertissement. Le réseau est donc prêt pour l’apprentissage. Pour exporter le réseau, cliquez sur Export. L’application enregistre le réseau dans la variable net_1.

Spécifier les options d’apprentissage

Spécifiez les options d’apprentissage. Le choix des options nécessite une analyse empirique. Pour explorer différentes configurations dans les options d’apprentissage au cours de vos expérimentations, vous pouvez utiliser l’application Experiment Manager.

options = trainingOptions("adam", ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=5, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Entraîner le réseau de neurones

Entraînez le réseau de neurones avec la fonction trainnet. Comme l’objectif est la classification, utilisez la perte d’entropie croisée.

net = trainnet(imdsTrain,net_1,"crossentropy",options);

Tester le réseau de neurones

Classez les images de test. Pour réaliser des prédictions avec plusieurs observations, utilisez la fonction minibatchpredict. Pour convertir les scores de prédiction en étiquettes, utilisez la fonction scores2label. La fonction minibatchpredict utilise automatiquement un GPU si disponible.

inputSize = net.Layers(1).InputSize(1:2);

augimdsTrain = augmentedImageDatastore(inputSize,imdsTest);

YTest = minibatchpredict(net,imdsTest);
YTest = scores2label(YTest,classNames);

Visualisez la précision de la classification dans un diagramme de confusion.

TTest = imdsTest.Labels;
figure
confusionchart(TTest,YTest);

Classer une nouvelle image

Classez une image de test. Lisez une image à partir d’un fichier JPEG, redimensionnez-la et convertissez-la en type de données single.

im = imread("MerchDataTest.jpg");

im = imresize(im,inputSize(1:2));
X = single(im);

Classez l’image. Pour faire une prédiction avec une seule observation, utilisez la fonction predict.

scores = predict(net,X);
[label,score] = scores2label(scores,classNames);

Affichez l’image avec l’étiquette prédite et le score correspondant.

figure
imshow(im)
title(string(label) + " (Score: " + gather(score) + ")")

Pour en savoir plus sur l’apprentissage par transfert et la façon d’améliorer la performance du réseau, veuillez consulter Retrain Neural Network to Classify New Images.

Références

[1] ImageNet. http://www.image-net.org.

[2] Iandola, Forrest N., Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, and Kurt Keutzer. "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model size." Preprint, submitted November 4, 2016. https://arxiv.org/abs/1602.07360.

[3] Iandola, Forrest N. "SqueezeNet." https://github.com/forresti/SqueezeNet.

Voir aussi

| | | |

Sujets associés