How to combine multiple LSTM networks into one net
I intend to train LSTMs on three sequences and then combine them into one 'net' for prediction, to speed up the training process. However, I'm having difficulty achieving this.
Answers (1)
Ben
on 9 Apr 2024
You can combine 3 separate LSTMs into one network by adding them to a dlnetwork object and connecting their outputs. Note that if the LSTMs have OutputMode="sequence", then either all input sequences must have the same length, or you need some layer(s) that can handle data with different sequence lengths.
Here's an example with OutputMode="last":
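```matlab
% Sizes for the three input branches and the shared head
inputSizes = [1,2,3];       % number of channels (features) in each input sequence
outputSize = 4;
lstmHiddenSize = 5;
hiddenSize = 10;
sequenceLengths = [6,7,8];  % lengths may differ because each LSTM only returns its last step

% Example inputs: one sequence per branch, formatted as channel-by-time ("CT")
x1 = dlarray(rand(inputSizes(1),sequenceLengths(1)),"CT");
x2 = dlarray(rand(inputSizes(2),sequenceLengths(2)),"CT");
x3 = dlarray(rand(inputSizes(3),sequenceLengths(3)),"CT");

% Branch 1 plus the shared head: this LSTM feeds input 1 of a 3-input concatenation layer
layers = [
    sequenceInputLayer(inputSizes(1))
    lstmLayer(lstmHiddenSize,OutputMode="last")
    concatenationLayer(1,3,Name="cat")   % concatenate the 3 LSTM outputs along dimension 1 (channels)
    fullyConnectedLayer(hiddenSize)
    reluLayer
    fullyConnectedLayer(outputSize)];
net = dlnetwork(layers,Initialize=false);

% Branches 2 and 3: add them as separate layer chains, then wire their LSTMs into "cat"
net = addLayers(net,[sequenceInputLayer(inputSizes(2));lstmLayer(lstmHiddenSize,OutputMode="last",Name="lstm2")]);
net = addLayers(net,[sequenceInputLayer(inputSizes(3));lstmLayer(lstmHiddenSize,OutputMode="last",Name="lstm3")]);
net = connectLayers(net,"lstm2","cat/in2");
net = connectLayers(net,"lstm3","cat/in3");
net = initialize(net);

% predict takes one input per input layer, in the order of net.InputNames
y = predict(net,x1,x2,x3)
```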
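For the OutputMode="sequence" case mentioned in the answer, here is a minimal sketch that takes the simplest route: all three inputs are assumed to share one sequence length, so the LSTM outputs line up time step by time step and can be concatenated along the channel dimension. The sizes, the explicit batch dimension ("CBT" format), and the layer names "in1"/"in2"/"in3"/"lstm1" are illustrative choices, not part of the original answer. Naming the input layers makes it easy to check net.InputNames, which fixes the order in which predict expects x1, x2, x3.

```matlab
% Illustrative sizes (hypothetical); all three inputs share ONE sequence length
inputSizes = [1,2,3];
outputSize = 4;
lstmHiddenSize = 5;
seqLength = 7;
batchSize = 2;

% Explicit batch dimension: channel-by-batch-by-time ("CBT")
x1 = dlarray(rand(inputSizes(1),batchSize,seqLength),"CBT");
x2 = dlarray(rand(inputSizes(2),batchSize,seqLength),"CBT");
x3 = dlarray(rand(inputSizes(3),batchSize,seqLength),"CBT");

% Branch 1 and a per-time-step head; every LSTM returns the full sequence
layers = [
    sequenceInputLayer(inputSizes(1),Name="in1")
    lstmLayer(lstmHiddenSize,OutputMode="sequence",Name="lstm1")
    concatenationLayer(1,3,Name="cat")   % channel-wise concat; time steps must match
    fullyConnectedLayer(outputSize)];    % applied at every time step
net = dlnetwork(layers,Initialize=false);

% Branches 2 and 3, wired into the concatenation layer
net = addLayers(net,[sequenceInputLayer(inputSizes(2),Name="in2"); ...
    lstmLayer(lstmHiddenSize,OutputMode="sequence",Name="lstm2")]);
net = addLayers(net,[sequenceInputLayer(inputSizes(3),Name="in3"); ...
    lstmLayer(lstmHiddenSize,OutputMode="sequence",Name="lstm3")]);
net = connectLayers(net,"lstm2","cat/in2");
net = connectLayers(net,"lstm3","cat/in3");
net = initialize(net);

disp(net.InputNames)          % predict maps x1, x2, x3 to these inputs, in this order
y = predict(net,x1,x2,x3);    % formatted "CBT": outputSize-by-batchSize-by-seqLength
```

If the sequence lengths differ, this variant no longer works as written; that is the situation where the answer's OutputMode="last" design, or extra layers that reconcile the lengths, is needed.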
2 comments
Narayan
on 26 Jun 2024
Mr. Ben, I have a question about your solution; it may be a similar query.
I want to train the LSTM model separately on two kinds of features and concatenate the LSTM layer outputs into a fully connected layer for multi-class classification. How can I do this during training? What format should xtrain and ytrain_label have for training the model? Thank you in advance.
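One possible way to train such a two-branch classifier, sketched under stated assumptions (this is not a reply from the thread): it assumes a release recent enough to provide trainnet, minibatchpredict and scores2label, and uses synthetic data. XTrain1, XTrain2 and YTrain are hypothetical names. Each feature group is an N-by-1 cell array whose cells are numFeatures-by-T matrices, and the labels are an N-by-1 categorical vector. Because the network has several inputs, the training data goes through a combined datastore whose reads are {x1, x2, label}, with the inputs in net.InputNames order followed by the target.

```matlab
% Hypothetical synthetic data: two feature groups per observation, one class label each
numObs = 60;
numFeatures1 = 3;      % channels in feature group 1 (assumption)
numFeatures2 = 4;      % channels in feature group 2 (assumption)
numClasses = 3;
lstmHiddenSize = 16;

% XTrain1, XTrain2: numObs-by-1 cell arrays, each cell a numFeatures-by-T matrix
XTrain1 = cell(numObs,1);
XTrain2 = cell(numObs,1);
for i = 1:numObs
    T = randi([10 20]);                 % lengths vary between observations
    XTrain1{i} = rand(numFeatures1,T);
    XTrain2{i} = rand(numFeatures2,T);
end
% YTrain: numObs-by-1 categorical vector with a fixed set of classes
YTrain = categorical(randi(numClasses,numObs,1),1:numClasses);

% Two LSTM branches (OutputMode="last") concatenated before the classifier head
layers = [
    sequenceInputLayer(numFeatures1,Name="in1")
    lstmLayer(lstmHiddenSize,OutputMode="last",Name="lstm1")
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer];
net = dlnetwork(layers,Initialize=false);
net = addLayers(net,[sequenceInputLayer(numFeatures2,Name="in2"); ...
    lstmLayer(lstmHiddenSize,OutputMode="last",Name="lstm2")]);
net = connectLayers(net,"lstm2","cat/in2");
net = initialize(net);

% Multi-input training data: each read returns {x1, x2, label}
dsX1 = arrayDatastore(XTrain1,OutputType="same");
dsX2 = arrayDatastore(XTrain2,OutputType="same");
dsY = arrayDatastore(YTrain);
dsTrain = combine(dsX1,dsX2,dsY);

% Left padding keeps each sequence's real last step at the end, which OutputMode="last" reads
options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    MiniBatchSize=16, ...
    SequencePaddingDirection="left", ...
    Verbose=false);
net = trainnet(dsTrain,net,"crossentropy",options);

% Predict class labels (here on the training sequences, just to show the call)
scores = minibatchpredict(net,XTrain1,XTrain2,SequencePaddingDirection="left");
YPred = scores2label(scores,categories(YTrain));
```

The softmaxLayer is included because trainnet does not add an output layer, and the "crossentropy" loss expects probabilities. Passing categorical labels lets trainnet handle the one-hot encoding.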