
Export LSTM to ONNX with proper input information

Brita Linnestad on 7 Jul 2022
I have created an LSTM network and converted it to ONNX using MATLAB's exportONNXNetwork. The ONNX network will be loaded in Java using OrtSession.
Layers: [5x1 nnet.cnn.layer.Layer]
  Sequence Input         Sequence input with 6 dimensions (numberOfFeatures)
  LSTM                   LSTM with 50 hidden units
  Fully Connected        2 fully connected layer (output size 2)
  Softmax                softmax
  Classification Output  crossentropyex
Sequence length is 24.
Using exportONNXNetwork(netlstm,filename), the only reported input is 'sequenceinput'.
How can I set up exportONNXNetwork so that the ONNX model holds more (or all) of the input information needed when loading the model in Java?

Answers (1)

Sivylla Paraskevopoulou on 7 Jul 2022
I am not sure what you mean by "more/all input information". If you mean that you want a network that can be used for prediction, you must train the layer graph that you created and then export the trained network, not the layer graph.
  2 comments
Brita Linnestad on 11 Jul 2022
I have trained the layer graph, and then exported the trained network.
When I load the trained network in Java with OrtSession and run it, I get an OrtSession runtime error:
Non-zero status code returned while running LSTM node. Name:'lstm' Status Message: Input initial_h must have shape {1,24,50}. Actual:{1,1,50}
Before exporting my model from MATLAB, how can I set initial_h, or any other information OrtSession needs, so that the model runs properly?
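One way to read this error is against the ONNX LSTM operator specification: with the default layout (layout = 0), the input X has shape [seq_length, batch_size, input_size] and initial_h has shape [num_directions, batch_size, hidden_size]. A required shape of {1,24,50} would then suggest that the LSTM node received 24 in the batch position (the 24 time steps treated as 24 observations), while the exported initial state appears to be fixed for a batch of 1. This is a plausible explanation, not a confirmed diagnosis. The Java sketch below only does that shape arithmetic; LstmShapeCheck and expectedInitialH are hypothetical helpers, and the sketch assumes the default layout and a unidirectional LSTM. It does not call ONNX Runtime.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: shape bookkeeping for the ONNX LSTM operator
 * with the default layout (layout = 0), assuming:
 *   X         : [seq_length, batch_size, input_size]
 *   initial_h : [num_directions, batch_size, hidden_size]
 */
class LstmShapeCheck {

    // Shape initial_h must have, given the X shape the LSTM node receives.
    static long[] expectedInitialH(long[] xShape, long numDirections, long hiddenSize) {
        if (xShape.length != 3) {
            throw new IllegalArgumentException("X must be rank 3");
        }
        long batchSize = xShape[1]; // layout 0: batch is the second dimension
        return new long[] {numDirections, batchSize, hiddenSize};
    }

    public static void main(String[] args) {
        long hiddenUnits = 50; // 50 hidden units, as in the question

        // Intended: 24 time steps, 1 observation, 6 features.
        long[] intended = {24, 1, 6};
        // Suspected: the 24 time steps ended up in the batch position.
        long[] swapped = {1, 24, 6};

        System.out.println(Arrays.toString(expectedInitialH(intended, 1, hiddenUnits))); // [1, 1, 50]
        System.out.println(Arrays.toString(expectedInitialH(swapped, 1, hiddenUnits)));  // [1, 24, 50]
    }
}
```

If the second case matches the runtime's complaint, the fix is more likely on the Java side (how the input tensor is shaped) than in exportONNXNetwork options.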
Sivylla Paraskevopoulou on 12 Jul 2022
In MATLAB, if your input data is a vector sequence, the sequenceInputLayer expects the data in the format CSN, where C is the number of features or channels, S is the sequence length, and N is the number of observations. For an example of how to train a network with vector sequence input, see Train Network for Sequence Classification.
When you export the network to ONNX, the input tensor shape should be NSC. I am not sure what happens to the input when you load the ONNX model with OrtSession.
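On the Java side, this means a sequence that is stored MATLAB-style as C features by S time steps (6 x 24 here) has to be transposed into an N x S x C array (1 x 24 x 6) before it is fed to the model. The sketch below assumes the exported input really is NSC, which is worth confirming by printing session.getInputInfo() in ONNX Runtime Java; SequenceLayout and csToNsc are hypothetical names. The resulting float[][][] is the kind of array that can be wrapped with OnnxTensor.createTensor(env, data) and passed under the input name 'sequenceinput'.

```java
/**
 * Hypothetical helper: rearranges one MATLAB-style sequence
 * (C features x S time steps, the "CS" part of CSN) into an
 * N x S x C array with N = 1, assuming the ONNX input is NSC.
 */
class SequenceLayout {

    static float[][][] csToNsc(float[][] cs) {
        int numFeatures = cs.length;    // C
        int seqLength = cs[0].length;   // S
        float[][][] nsc = new float[1][seqLength][numFeatures];
        for (int c = 0; c < numFeatures; c++) {
            for (int s = 0; s < seqLength; s++) {
                nsc[0][s][c] = cs[c][s]; // swap the feature and time axes
            }
        }
        return nsc;
    }

    public static void main(String[] args) {
        int numFeatures = 6;
        int seqLength = 24;
        float[][] cs = new float[numFeatures][seqLength];
        for (int c = 0; c < numFeatures; c++) {
            for (int s = 0; s < seqLength; s++) {
                cs[c][s] = 100 * c + s; // encode (feature, step) in the value
            }
        }
        float[][][] nsc = csToNsc(cs);
        System.out.println(nsc.length + " x " + nsc[0].length + " x " + nsc[0][0].length); // 1 x 24 x 6
        System.out.println(nsc[0][23][5]); // feature 5 at step 23: 523.0
    }
}
```

Passing the data untransposed, or with the time axis in the N position, is one way the LSTM could end up seeing a batch size of 24.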


Categories

Deep Learning Toolbox

Version

R2021a
