resetState

Reset the state of a recurrent neural network

Syntax

updatedNet = resetState(recNet)

Description

updatedNet = resetState(recNet) resets the state of a recurrent neural network (for example, an LSTM network) to the initial state.
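To see why a reset matters, consider a minimal sketch of a recurrent state in plain MATLAB. It is only an illustration: the scalar cell, the weights w and u, and the step function are hypothetical and do not reflect how the toolbox stores or updates LSTM state. The sketch shows that the same input sequence ends in a different state depending on whether the state left by an earlier sequence is carried over or reset.

```matlab
% Toy scalar recurrent cell: h_t = tanh(w*x_t + u*h_prev).
% Illustration only; w, u, and step are hypothetical, not toolbox code.
w = 0.5;                          % input weight (assumed value)
u = 0.9;                          % recurrent weight (assumed value)
h0 = 0;                           % initial state
step = @(h,x) tanh(w*x + u*h);    % one time step of the toy cell

% Process a first sequence; it leaves a nonzero state behind.
h = h0;
for x = [1 1 1]
    h = step(h,x);
end
carriedState = h;

% Process a second sequence twice: once continuing from the carried
% state, and once starting again from the initial state.
hCarried = carriedState;
hReset = h0;
for x = [-1 -1]
    hCarried = step(hCarried,x);
    hReset = step(hReset,x);
end

fprintf('final state without reset: %.4f\n', hCarried);
fprintf('final state after reset:   %.4f\n', hReset);
```

The two printed values differ, which is the effect that resetState removes for a real network.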

Examples

Reset the network state between sequence predictions.

To reproduce the results in this example, set rng to 'default'.

rng('default')

Load JapaneseVowelsNet, a pretrained long short-term memory (LSTM) network trained on the Japanese Vowels data set as described in [1] and [2]. This network was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet

View the network architecture.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Load the test data.

load JapaneseVowelsTest

Classify a sequence and update the network state.

X = XTest{94};
[net,label] = classifyAndUpdateState(net,X);

Classify another sequence using the updated network.

X = XTest{1};
label = classify(net,X)
label = categorical
     7 

Compare the final prediction with the true label.

trueLabel = YTest(1)
trueLabel = categorical
     1 

The state carried over from the previous sequence may have caused the misclassification. Reset the network state and classify the sequence again.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
     1 
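When each sequence is independent, one way to avoid carry-over is to reset the state before every call that updates it. The following is a sketch under the assumption that net, XTest, and YTest are loaded as in the example above; the variable names numSeq and isCorrect are introduced here for illustration only.

```matlab
% Assumes net, XTest, and YTest are already in the workspace
% (load JapaneseVowelsNet and load JapaneseVowelsTest, as above).
numSeq = 3;                       % number of sequences to check (arbitrary)
isCorrect = false(numSeq,1);      % one flag per sequence

for i = 1:numSeq
    % Start every independent sequence from the initial state.
    net = resetState(net);

    % Classify the sequence; this call also updates the network state.
    [net,label] = classifyAndUpdateState(net,XTest{i});

    % Record whether the prediction matches the true label.
    isCorrect(i) = (label == YTest(i));
end

disp(isCorrect)
```

Resetting inside the loop keeps each prediction independent of the order in which the sequences are processed.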

Input Arguments

recNet: trained recurrent neural network, specified as a SeriesNetwork object. You can get a trained network by importing a pretrained network or by training your own network using the trainNetwork function.

The network must contain at least one recurrent layer (for example, an LSTM layer).

Output Arguments

updatedNet: updated network with its state reset to the initial state, returned as a SeriesNetwork object.

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Introduced in R2017b