Initializing LSTM which is imported using ONNX

Andreas on 17 Jul 2024
Answered: Andreas on 23 Jul 2024
Hi,
I am training an LSTM for RL using Ray in Python. I would like to export this model to ONNX and afterwards import it into MATLAB. As far as I understand, I need to initialize the model in MATLAB after importing it. However, I cannot figure out the correct input shapes/formats in MATLAB to make this work.
Minimum working example:
Python code to train the LSTM:
import torch
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig
# Configure the algorithm
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)
# Train for 2 iterations (each algo.train() call runs one training iteration, not one episode)
for i in range(2):
    result = algo.train()
# Get the policy
ppo_policy = algo.get_policy()
# Batch size
B = 1
# Initialize the LSTM inputs: observations, the two LSTM state tensors, and sequence lengths
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B, 4)).astype(np.float32))}
state_batches = [torch.zeros((B, 256), dtype=torch.float32), torch.zeros((B, 256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=torch.int64)
# Apply the LSTM to the inputs
model = ppo_policy.model
print(model(input_dict, state=state_batches, seq_lens=seq_lens))
# Save the model to ONNX (opset 14)
ppo_policy.export_model('onnx14', onnx=14)
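For context on the `seq_lens` argument in the Python call: RLlib's recurrent models receive observations as a flat batch together with per-sequence lengths and, roughly speaking, reorganize them into a padded [batch, time, feature] layout before the LSTM runs. The plain-Python sketch below emulates that step on nested lists; `pad_to_sequences` is a hypothetical helper, not RLlib's API, and it assumes the rows of each sequence are stored one after another.

```python
from typing import List, Sequence

Row = List[float]


def pad_to_sequences(flat_rows: Sequence[Row], seq_lens: Sequence[int],
                     pad_value: float = 0.0) -> List[List[Row]]:
    """Hypothetical helper: split consecutively stored sequences and pad them
    to a common length, giving a [batch][time][feature] nested list."""
    if sum(seq_lens) != len(flat_rows):
        raise ValueError("seq_lens must sum to the number of rows")
    max_t = max(seq_lens, default=0)
    width = len(flat_rows[0]) if flat_rows else 0
    batch: List[List[Row]] = []
    start = 0
    for length in seq_lens:
        # Copy the rows belonging to this sequence
        seq = [list(row) for row in flat_rows[start:start + length]]
        # Pad shorter sequences with rows of pad_value up to the longest length
        seq += [[pad_value] * width for _ in range(max_t - length)]
        batch.append(seq)
        start += length
    return batch
```

With the question's inputs (B = 1, `seq_lens` = [1]), this yields one sequence containing a single time step of 4 features; with two sequences of lengths 2 and 1, the second one is padded with a zero row.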
Code in Matlab:
% Import model from where I saved it
net = importNetworkFromONNX('path/to/onnx-model');
% input shapes
obs_size = [1,4];
state_size=[2,1,256];
seq_lens_size=[1];
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
Error message:
I appreciate any help!
Best,
Andreas
2 comments
Nilesh on 17 Jul 2024
Edited: Nilesh on 17 Jul 2024
Hello Andreas,
Have you tried asking ChatGPT about your issue?
Andreas on 17 Jul 2024
Yes, but without success so far.


Answers (3)

Joss Knight on 18 Jul 2024
This code is suspect:
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
I think your network has a single input, so you need to pass a single input to initialize (along with the network): basically just an example input exactly like the one you want to pass to predict. I think you have two channels and a sequence length of 256? One of your dimensions is time, so you need a T dimension. And I don't think you have any spatial dimensions, so no S labels. So you need something like
exampleInput = dlarray(rand(2,1,256),'CBT');
net = initialize(net, exampleInput);
Or if you prefer, a permutation of that like
exampleInput = dlarray(rand(256,2,1),'TCB');
net = initialize(net, exampleInput);
If this doesn't work, try running analyzeNetwork(net) to see where your inputs are and we can work out what to expect.
1 comment
Andreas on 23 Jul 2024
Hi,
the network does not have a single input. I managed to solve the issue; see my answer below. Thank you for your help anyway!



Kaustab Pal on 19 Jul 2024
It seems you want to determine the input dimensions of your imported network. You can find this information using the analyzeNetwork function, which provides an interactive visualization of the network architecture along with detailed information, including:
  • Layer types
  • Sizes and formats of layer learnable parameters
  • States and activations
  • Total number of learnable parameters
The activation size of the topmost layer will give you the input dimension.
Additionally, when creating dlarray objects in MATLAB, you need to specify the format, which must follow this order:
  • "S" (Spatial)
  • "C" (Channel)
  • "B" (Batch)
  • "T" (Time)
  • "U" (Unspecified)
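According to the linked dlarray documentation, this is the order into which the software automatically permutes the dimensions when you create a formatted dlarray, so the labels in the format string you pass can appear in any order. The sketch below emulates that reordering in plain Python as a stable sort of the dimensions by label; `canonical_layout` is a hypothetical helper, not a MATLAB function, and it ignores details such as how MATLAB reports trailing singleton dimensions.

```python
from typing import Sequence, Tuple

# Order into which formatted dlarray dimensions are permuted, per the dlarray docs
CANONICAL_ORDER = "SCBTU"


def canonical_layout(shape: Sequence[int], fmt: str) -> Tuple[Tuple[int, ...], str]:
    """Hypothetical helper emulating the dlarray reordering: stable-sort the
    dimensions so that S < C < B < T < U, keeping repeated S (or U) labels
    in their original relative order."""
    if len(shape) != len(fmt):
        raise ValueError("need exactly one label per dimension")
    unknown = set(fmt) - set(CANONICAL_ORDER)
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    # Only S and U may repeat; C, B and T can each appear at most once
    for label in "CBT":
        if fmt.count(label) > 1:
            raise ValueError(f"label {label!r} used more than once")
    order = sorted(range(len(fmt)), key=lambda i: CANONICAL_ORDER.index(fmt[i]))
    return tuple(shape[i] for i in order), "".join(fmt[i] for i in order)
```

Under this emulation, Joss's two suggested inputs end up identical: both `canonical_layout((2, 1, 256), "CBT")` and `canonical_layout((256, 2, 1), "TCB")` give `((2, 1, 256), "CBT")`. The question's state array (`"SBS"`, 2×1×256) comes out as `"SSB"` with size 2×256×1, i.e. two spatial dimensions rather than a time dimension.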
For more details, you can refer to the following links:
  1. analyzeNetwork Documentation: https://www.mathworks.com/help/deeplearning/ref/analyzenetwork.html#mw_bdd24886-fa03-4540-a111-391541a0a684
  2. dlarray Documentation: https://www.mathworks.com/help/deeplearning/ref/dlarray.html#d126e57736:~:text=When%20you%20create%20a%20formatted%20dlarray%20object%2C%20the%20software%20automatically%20permutes%20the%20dimensions%20such%20that%20the%20format%20has%20dimensions%20in%20this%20order%3A
Hope this helps.
1 comment
Joss Knight on 19 Jul 2024
Just FYI, the formats do not have to follow that order.



Andreas on 23 Jul 2024
Hello everyone,
thank you for your help. Unfortunately, I had to work around this issue, but I could solve it in the end. I believe MATLAB struggles because the models in Ray's RLlib contain a lot of complicated overhead. In particular, the inputs to the network are lists/dicts etc., which undergo considerable reformatting, and this seemed to cause the issues. In the end, I extracted the relevant torch models from the trained RLlib object and combined them in a new torch.nn.Module object. Exporting this object with torch.onnx.export worked just fine.
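The workaround described above can be sketched as follows. The thread does not show which submodules were pulled out of the RLlib model, so the sketch uses hypothetical stand-ins (`encoder`, `lstm`, `logits_head`, built by the hypothetical `build_example_policy`) in place of the trained layers; the sizes (4 observations, 256 hidden units, 2 actions) are assumptions matching CartPole and the state size in the question. The key idea is that the wrapper's `forward` takes and returns plain tensors only, so the exported ONNX graph has explicit `obs`, `h0` and `c0` inputs instead of dicts and lists.

```python
import torch
import torch.nn as nn


class ExportableLSTMPolicy(nn.Module):
    """Hypothetical wrapper around extracted submodules: plain tensors in and out."""

    def __init__(self, encoder: nn.Module, lstm: nn.LSTM, logits_head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.lstm = lstm
        self.logits_head = logits_head

    def forward(self, obs, h0, c0):
        # obs: (B, T, obs_dim); h0, c0: (num_layers, B, hidden) even with batch_first=True
        features = self.encoder(obs)
        out, (h1, c1) = self.lstm(features, (h0, c0))
        # Return the new LSTM state so the caller can feed it back at the next step
        return self.logits_head(out), h1, c1


def build_example_policy(obs_dim: int = 4, hidden: int = 256, num_actions: int = 2):
    # Stand-ins for the trained submodules you would extract from the RLlib model
    encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
    lstm = nn.LSTM(hidden, hidden, batch_first=True)
    logits_head = nn.Linear(hidden, num_actions)
    return ExportableLSTMPolicy(encoder, lstm, logits_head).eval()


def export_to_onnx(policy: ExportableLSTMPolicy, path: str,
                   obs_dim: int = 4, hidden: int = 256) -> None:
    # Example inputs fix the traced shapes: batch size 1, one time step
    example_inputs = (
        torch.zeros(1, 1, obs_dim),
        torch.zeros(1, 1, hidden),
        torch.zeros(1, 1, hidden),
    )
    torch.onnx.export(
        policy,
        example_inputs,
        path,
        input_names=["obs", "h0", "c0"],
        output_names=["logits", "h1", "c1"],
        opset_version=14,
    )
```

Calling `export_to_onnx(build_example_policy(), "lstm_policy.onnx")` writes a model with three separate inputs; after `importNetworkFromONNX`, running analyzeNetwork on it should show which formats MATLAB assigned to each one. Without the `dynamic_axes` argument of `torch.onnx.export`, the batch and time sizes stay fixed at the example values.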
Thank you all for your help.
Best, Andreas

Categories

Find more on Sequence and Numeric Feature Data Workflows in Help Center and File Exchange

Version

R2024a
