LSTM Layer input size.

Alex on 17 Sep 2023
Commented: Alex on 18 Sep 2023
Hi all, quick question.
I am learning regression LSTMs, and in the following code I have a TrainX variable shaped [batchSize, sequenceLength, inputSize], which is [5950 x 14 x 5].
clc; clear; close all;
% LSTM - Test 1 (AAPL)
% Read the CSV file into a table
data = readtable('AAPL.csv');
% Display the first few rows of the table
head(data)
% Extract data
Date = data.Date;
OpenP = data.Open;
HighP = data.High;
LowP = data.Low;
CloseP = data.Close;
CloseAdjP = data.AdjClose;
Volume = data.Volume;
% Min-max scaling
scaledOpenP = (OpenP - min(OpenP)) / (max(OpenP) - min(OpenP));
scaledHighP = (HighP - min(HighP)) / (max(HighP) - min(HighP));
scaledCloseP = (CloseP - min(CloseP)) / (max(CloseP) - min(CloseP));
scaledLowP = (LowP - min(LowP)) / (max(LowP) - min(LowP));
scaledCloseAdjP = (CloseAdjP - min(CloseAdjP)) / (max(CloseAdjP) - min(CloseAdjP));
scaledVolume = (Volume - min(Volume)) / (max(Volume) - min(Volume));
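% note: rescale(x) computes the same (x - min) / (max - min) min-max
% scaling in one call, e.g. scaledOpenP = rescale(OpenP);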
TrainX = [];
TrainY = [];
n_future = 1;
n_past = 14;
len = length(scaledCloseP);
for i = 1:len-n_past-n_future+1
    % Extract sequences for TrainX
    TrainX(end+1, :, 1) = scaledOpenP(i:i+n_past-1);
    TrainX(end, :, 2) = scaledHighP(i:i+n_past-1);
    %TrainX(end, :, 3) = scaledCloseP(i:i+n_past-1);
    TrainX(end, :, 3) = scaledLowP(i:i+n_past-1);
    TrainX(end, :, 4) = scaledCloseAdjP(i:i+n_past-1);
    TrainX(end, :, 5) = scaledVolume(i:i+n_past-1);
    % Next day close price for TrainY
    TrainY(end+1, 1) = scaledCloseP(i+n_past+n_future-1);
end
%TrainX = permute(TrainX, [2, 1, 3]);
% Define the number of features
numFeatures = [14 5];
numHiddenUnits1 = 500; % First LSTM layer
numHiddenUnits2 = 200; % Second LSTM layer
dropoutRate = 0.2; % Dropout rate
numResponses = 1;
layers = [ ...
    sequenceInputLayer(numFeatures, Normalization="zscore")
    lstmLayer(numHiddenUnits1)
    fullyConnectedLayer(numResponses)
    regressionLayer];
analyzeNetwork(layers);
options = trainingOptions('adam', ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 16, ...
    'Verbose', 1, ...
    'Plots', 'training-progress');
% Train the network
net = trainNetwork(TrainX, TrainY, layers, options);
I have used the standard input shape for an LSTM (at least, that's how I did it in Python).
I think the issue is numFeatures = [14 5]. How should I specify it?
When I run the code, MATLAB gives the following error:
Error using trainNetwork
Invalid network.
Error in LSTM_1 (line 86)
net = trainNetwork(TrainX, TrainY, layers, options);
Caused by:
Layer 2: LSTM layers must have scalar input size, but input size (14×5) was received. Try using a
flatten layer before the LSTM layer.

Accepted Answer

Ben on 18 Sep 2023
For sequenceInputLayer you don't need to specify the sequence length as a feature, so you would just need numFeatures = 5.
For batches of sequence data in trainNetwork, each observation in the batch needs to be a cell; this applies to both the input and output sequences, so that each sequence can have a different length. Additionally, each cell should contain an array of size "NumFeatures x SequenceLength".
Here's one way you could do that with data like yours:
batchSize = 5950;
sequenceLength = 14;
numFeatures = 5;
% generate random data for example
TrainX = randn(batchSize,sequenceLength,numFeatures);
% permute batch x sequence x features -> features x sequence x batch
TrainX = permute(TrainX,[3,2,1]);
% convert to cell
TrainX = num2cell(TrainX,[1,2]);
% flatten into a column vector of cells
TrainX = TrainX(:);
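After the conversion, each cell should hold a numFeatures x sequenceLength matrix; a quick sanity check (using the example sizes above):
size(TrainX)     % 5950x1 cell array
size(TrainX{1})  % 5x14, i.e. numFeatures x sequenceLength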
To demonstrate, here's how you would use trainNetwork to train an LSTM that attempts to just memorise the input:
layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(numFeatures)
    regressionLayer];
opts = trainingOptions("sgdm");
trainNetwork(TrainX,TrainX,layers,opts)
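Applied back to the question's price data, an untested sketch might look like the following; note OutputMode="last" on the LSTM layer, so it returns one output per sequence to match the numeric next-day-close targets in TrainY:
% sketch, assuming TrainX has been converted to a 5950x1 cell array as above
% and TrainY is left as a 5950x1 numeric vector (one response per sequence)
numFeatures = 5;
numHiddenUnits1 = 500;
numResponses = 1;
layers = [
    sequenceInputLayer(numFeatures, Normalization="zscore")
    lstmLayer(numHiddenUnits1, OutputMode="last") % one output per sequence
    fullyConnectedLayer(numResponses)
    regressionLayer];
options = trainingOptions('adam', ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 16, ...
    'Plots', 'training-progress');
net = trainNetwork(TrainX, TrainY, layers, options);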
1 comment
Alex on 18 Sep 2023
That worked.
Thank you very much! :)
