Visualiser les activations d’un réseau LSTM
Cet exemple indique comment examiner et visualiser les caractéristiques apprises par les réseaux LSTM en extrayant les activations.
Chargez le réseau préentraîné. JapaneseVowelsNet
est un réseau LSTM préentraîné sur le jeu de données des voyelles japonaises comme décrit dans [1] et [2]. Il a été entraîné sur les séquences triées par longueur de séquence avec une taille de mini-batch de 27.
load JapaneseVowelsNet
Affichez l’architecture du réseau.
net.Layers
ans = 4×1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax
Chargez les données de test.
load JapaneseVowelsTestData
Visualisez les premières séries temporelles dans un graphique. Chaque ligne correspond à une caractéristique.
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location',"northeastoutside")
Pour chaque pas de temps des séquences, obtenez les activations générées par la couche LSTM (couche 2) pour ce pas de temps et mettez à jour l'état du réseau.
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength [features(i,:),state] = predict(net,X(:,1)',Outputs="lstm"); net.State = state; end features = features';
Visualisez les 10 premières unités cachées au moyen d'une carte thermique.
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
La carte thermique indique l'intensité de l'activation de chaque unité cachée et met en évidence l'évolution des activations à travers le temps.
Références
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Voir aussi
trainnet
| trainingOptions
| dlnetwork
| predict
| forward
| lstmLayer
| bilstmLayer
| sequenceInputLayer