Contenu principal

Visualiser les activations d’un réseau LSTM

Cet exemple indique comment examiner et visualiser les caractéristiques apprises par les réseaux LSTM en extrayant les activations.

Chargez le réseau préentraîné. JapaneseVowelsNet est un réseau LSTM préentraîné sur le jeu de données des voyelles japonaises comme décrit dans [1] et [2]. Il a été entraîné sur les séquences triées par longueur de séquence avec une taille de mini-batch de 27.

load JapaneseVowelsNet

Affichez l’architecture du réseau.

net.Layers
ans = 
  4×1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax

Chargez les données de test.

load JapaneseVowelsTestData

Visualisez les premières séries temporelles dans un graphique. Chaque ligne correspond à une caractéristique.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location',"northeastoutside")

Figure contains an axes object. The axes object with title Test Observation 1, xlabel Time Step contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Pour chaque pas de temps des séquences, obtenez les activations générées par la couche LSTM (couche 2) pour ce pas de temps et mettez à jour l'état du réseau.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    [features(i,:),state] = predict(net,X(:,1)',Outputs="lstm");
    net.State = state;
end
features = features';

Visualisez les 10 premières unités cachées au moyen d'une carte thermique.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

La carte thermique indique l'intensité de l'activation de chaque unité cachée et met en évidence l'évolution des activations à travers le temps.

Références

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Voir aussi

| | | | | | |

Rubriques