updatedNet = resetState(recNet) resets the state of a recurrent neural network (for example, an LSTM network) to the initial state.
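A common reason to call resetState is after streaming a sequence through the network one time step at a time. The sketch below assumes the same JapaneseVowelsNet network and japaneseVowelsTestData helper used in the example that follows. It feeds each time step to predictAndUpdateState so that the LSTM state carries over between calls, and then resets the state before any unrelated sequence is processed.

```matlab
load JapaneseVowelsNet                    % pretrained LSTM, loaded into the variable net
[XTest,~] = japaneseVowelsTestData;

X = XTest{1};                             % 12-by-numTimeSteps sequence
for t = 1:size(X,2)
    % Predict on a single time step; the returned net carries the updated state
    [net,scores] = predictAndUpdateState(net,X(:,t));
end

% Return the network to its initial state before starting an unrelated sequence
net = resetState(net);
```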


Reset the network state between sequence predictions.

Load JapaneseVowelsNet, a pretrained long short-term memory (LSTM) network trained on the Japanese Vowels data set as described in [1] and [2]. This network was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet

View the network architecture.

net.Layers

ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Load the test data.

[XTest,YTest] = japaneseVowelsTestData;

Classify a sequence and update the network state. For reproducibility, set rng to 'default'.

rng('default')
X = XTest{94};
[net,label] = classifyAndUpdateState(net,X)
label = categorical

Classify another sequence using the updated network.

X = XTest{1};
label = classify(net,X)
label = categorical

Compare the final prediction with the true label.

trueLabel = YTest(1)
trueLabel = categorical

The network state carried over from classifying XTest{94} may have negatively influenced the classification. Reset the network state and classify the sequence again.

net = resetState(net);
label = classify(net,XTest{1})
label = categorical
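When sequences are independent, the same idea extends to a loop: reset the state before each sequence so that no prediction depends on the previous one. This is a minimal sketch, assuming the JapaneseVowelsNet network and japaneseVowelsTestData helper from this example; the variable numObs (the number of test sequences to classify) is a hypothetical choice for illustration.

```matlab
load JapaneseVowelsNet                    % pretrained LSTM, loaded into the variable net
[XTest,YTest] = japaneseVowelsTestData;

numObs = 5;                               % hypothetical number of sequences to classify
labels = YTest(1:numObs);                 % preallocate a categorical array of the right size
for i = 1:numObs
    % Start every sequence from the initial state
    net = resetState(net);
    [net,labels(i)] = classifyAndUpdateState(net,XTest{i});
end

accuracy = mean(labels == YTest(1:numObs))
```

Because the state is reset at the start of every iteration, each label here should match what classify returns for a freshly reset network.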

Input Arguments

recNet

Trained recurrent neural network, specified as a SeriesNetwork or a DAGNetwork object. You can get a trained network by importing a pretrained network or by training your own network using the trainNetwork function.

recNet is a recurrent neural network. It must have at least one recurrent layer (for example, an LSTM layer). If the input network is not a recurrent network, then the function has no effect and returns the input network.
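Because resetState has no effect on a network without recurrent layers, an explicit check is optional. If you want one anyway, the sketch below inspects the layer classes. The list of recurrent layer classes (LSTM, BiLSTM, and GRU) is an assumption for illustration and may not cover every recurrent layer type in your release; isa returns false for class names that do not exist, so the list is safe to use on older releases.

```matlab
load JapaneseVowelsNet                    % pretrained LSTM, loaded into the variable net

% Layer classes treated as recurrent in this sketch (assumed list; extend as needed)
recurrentClasses = {'nnet.cnn.layer.LSTMLayer', ...
                    'nnet.cnn.layer.BiLSTMLayer', ...
                    'nnet.cnn.layer.GRULayer'};

hasRecurrent = false;
for i = 1:numel(net.Layers)
    layer = net.Layers(i);
    % Mark the network as recurrent if this layer matches any listed class
    if any(cellfun(@(c) isa(layer,c), recurrentClasses))
        hasRecurrent = true;
    end
end

if hasRecurrent
    net = resetState(net);
end
```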

Output Arguments

updatedNet

Updated network. updatedNet is the same type of network as the input network.

If the input network is not a recurrent network, then the function has no effect and returns the input network.
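A quick way to confirm that the output has the same type as the input is to compare the classes of the two networks. This sketch assumes the JapaneseVowelsNet network from the example above; the same comparison applies to any SeriesNetwork or DAGNetwork input.

```matlab
load JapaneseVowelsNet                    % pretrained LSTM, loaded into the variable net
updatedNet = resetState(net);

% resetState returns a network of the same class as its input
fprintf('Input class: %s, output class: %s\n', class(net), class(updatedNet));
```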


