LSTM not outputting sequence
I am attempting to do sequence-to-sequence classification.
I have N time series, each a sequence of observations, and each observation collects p features.
I build a cell array XTrain and set XTrain{i} to the i-th time series in my database.
I have two classes. I build a cell array YTrain, where YTrain{i} is a categorical vector telling me which class is present at each time step.
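To make the expected data layout concrete, the following sketch builds synthetic XTrain and YTrain from random values. It assumes hypothetical sizes (N = 4, p = 6, series lengths between 20 and 50) and hypothetical class names 'A' and 'B'. It also assumes that, because inputSize = [p, 1, 1] makes sequenceInputLayer expect image sequences, each XTrain{i} is stored as a p-by-1-by-1-by-T array rather than a p-by-T matrix; the question does not state the layout it used.

```matlab
% Hypothetical sizes, for illustration only
N = 4;                  % number of time series
p = 6;                  % features per observation
classNames = {'A', 'B'};  % hypothetical labels for the two classes

XTrain = cell(N, 1);
YTrain = cell(N, 1);
for i = 1:N
    T = randi([20 50]);   % length of the i-th series (varies per series)
    % inputSize = [p 1 1] means image-sequence input, so store the
    % series as a p-by-1-by-1-by-T array (height p, width 1, 1 channel).
    XTrain{i} = reshape(randn(p, T), p, 1, 1, T);
    % One label per time step: a 1-by-T categorical row vector
    YTrain{i} = categorical(classNames(randi(2, 1, T)), classNames);
end
```

With real data, the randn and randi calls would be replaced by your own observations and per-time-step labels.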
Now I build the following network:
inputSize = [p, 1, 1];
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;
layers = [ ...
sequenceInputLayer(inputSize,'Name','input')
sequenceFoldingLayer('Name','fold')
convolution2dLayer(filterSize,numFilters,'Name','conv1')
reluLayer('Name','relu1')
convolution2dLayer(filterSize,numFilters,'Name','conv2')
reluLayer('Name','relu2')
flattenLayer('Name','flatten')
sequenceUnfoldingLayer('Name','unfold')
lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')
classificationLayer('Name','classification')];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
maxEpochs = 1;
miniBatchSize = 2;
options = trainingOptions('adam', ...
'ExecutionEnvironment','cpu', ...
'MaxEpochs',maxEpochs, ...
'MiniBatchSize',miniBatchSize, ...
'GradientThreshold',1, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,lgraph,options);
However, if I then run:
YScores = predict(net,XTrain,'MiniBatchSize',1);
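the output is a cell array whose i-th entry is a single vector of class probabilities, i.e. one set of scores for the whole i-th series.
This is INCORRECT. Because the LSTM uses 'OutputMode','sequence', each entry should contain a vector of class probabilities for every time step of the i-th series.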
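To check whether predict returns one score vector per sequence or one per time step, a quick check can compare the number of scores in each cell with numClasses times the sequence length. hasPerStepScores is a hypothetical helper, and it assumes each XTrain{i} is a p-by-1-by-1-by-T array, so the sequence length is read from the 4th dimension.

```matlab
% Hypothetical helper: true if every YScores{i} holds K scores for each
% time step of XTrain{i}. Counting elements avoids assuming whether the
% scores are laid out K-by-T or T-by-K. (:) makes row/column cells agree.
hasPerStepScores = @(YScores, XTrain, K) ...
    all(cellfun(@(y, x) numel(y) == K * size(x, 4), YScores(:), XTrain(:)));

% Usage with the trained network from the question:
% YScores = predict(net, XTrain, 'MiniBatchSize', 1);
% hasPerStepScores(YScores, XTrain, numClasses)
```

With the behavior described above (one vector per series), this check would return false.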
4 Comments
John Malik
on 18 Dec 2019
Ridwan Alam
on 25 Dec 2019
Edited: Ridwan Alam
on 25 Dec 2019
Hey John, did you get a solution? Please share. Thanks!
Mohammad Sami
on 26 Dec 2019
Can you try putting the sequenceUnfoldingLayer before the flattenLayer?
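Mohammad Sami's suggestion can be sketched as follows: the same layers as in the question, with only sequenceUnfoldingLayer and flattenLayer swapped. The value p = 6 is a placeholder for your own feature count. This fold → convolution → unfold → flatten → LSTM ordering matches the one used in MathWorks' CNN-LSTM video-classification example; the thread does not confirm that it fixes the output shape here.

```matlab
p = 6;                  % placeholder: use your own number of features
inputSize = [p, 1, 1];
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;
layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    sequenceFoldingLayer('Name','fold')
    convolution2dLayer(filterSize,numFilters,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(filterSize,numFilters,'Name','conv2')
    reluLayer('Name','relu2')
    sequenceUnfoldingLayer('Name','unfold')   % restore the sequence dimension first
    flattenLayer('Name','flatten')            % then flatten per time step
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];
lgraph = layerGraph(layers);
% The unfolding layer still needs the mini-batch size from the folding layer
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
```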
John Malik
on 26 Dec 2019
Edited: John Malik
on 26 Dec 2019
Answers (0)
Category: Deep Learning Toolbox