機械学習の入力エラーについて

2 vues (au cours des 30 derniers jours)
Yuuki
Yuuki le 23 Nov 2020
Commenté : Yuuki le 30 Nov 2020
LSTMの学習方法について質問です.
最下部に示したコードを実行したとき,「予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。」が表示されうまく学習できません.
入力データは1タイムステップに t-2, t-1, t のデータが含まれており,それに対応する出力データは t+1 のデータとなっています.
ここで学習に用いるデータを
net = trainNetwork(XTrain_C, YTrain_C, layers, options);
のように,Cのみを用いるようにすると上手く実行できるのですが,元コードのようにA,C,Dの3つの時系列データを学習させたモデルを作成しようとするとエラー文が表示されてしまいます.
下記のコードをどう修正すれば実行可能になりますでしょうか?
clear all
close all
%% Make dataset
A = zeros(1,100);
B = zeros(1,100);
C = zeros(1,100);
D = zeros(1,100);
% A
for i = 1:100
if i <= 40
A(:,i) = i / 40;
elseif (41 <= i) && (i <= 45)
A(:,i) = 1 - ((i - 40) / 5);
elseif 46 <= i
A(:,i) = 0;
end
end
% B
for i = 1:100
if i <= 60
B(:,i) = i / 60;
elseif (61 <= i) && (i <= 65)
B(:,i) = 1 - ((i - 60) / 5);
elseif 66 <= i
B(:,i) = 0;
end
end
% C
for i = 1:100
if i <= 80
C(:,i) = i / 80;
elseif (81 <= i) && (i <= 85)
C(:,i) = 1 - ((i - 80) / 5);
elseif 86 <= i
C(:,i) = 0;
end
end
% D
for i = 1:100
if i <= 40
D(:,i) = i / 20;
elseif (21 <= i) && (i <= 25)
D(:,i) = 1 - ((i - 20) / 5);
elseif 26 <= i
D(:,i) = 0;
end
end
%% Plot
plot(1:100, A(1,:),'LineWidth',2);hold on
plot(1:100, B(1,:),'LineWidth',2);hold on
plot(1:100, C(1,:),'LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('A','B','C','Location','northwest')
grid on
%% Preparing for ML
% A
for i = 1:97
XTrain_A{1,i} = A(:,i:i+2).';
YTrain_A{1,i} = A(:,i+3);
end
% C
for i = 1:97
XTrain_C{1,i} = C(:,i:i+2).';
YTrain_C{1,i} = C(:,i+3);
end
% D
for i = 1:97
XTrain_D{1,i} = D(:,i:i+2).';
YTrain_D{1,i} = D(:,i+3);
end
% Input
XTrain{1,1} = XTrain_D;
XTrain{2,1} = XTrain_A;
XTrain{3,1} = XTrain_C;
YTrain{1,1} = YTrain_D;
YTrain{2,1} = YTrain_A;
YTrain{3,1} = YTrain_C;
%% TrainNetwork
numFeatures = 3;
numResponses = 1;
numHiddenUnits = 300;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(20)
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('adam', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.0001, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',50, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
%% Test
Result = zeros(1,100);
Result(:,1:3) = B(1,1:3);
for i = 1:97
[net,Result(1,i+1)] = predictAndUpdateState(net, Result(:,i:i+2).');
end
%% Plot result
plot(1:100, B(1,:),'k','LineWidth',2);hold on
plot(1:100, Result(1,:),'r','LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('B','Predection','Location','northwest')
grid on
  7 commentaires
Naoya
Naoya le 30 Nov 2020
はい、その理解となります。 t-2, t-1 を含めずにまずはお試し頂ければと思います。
Yuuki
Yuuki le 30 Nov 2020
Naoya様
ご返信ありがとうございます.
何度か5×Sで試したもののあまり精度が出ず,他論文でt-5~tのデータを入力としt+1を出力する例を見たため,同様の方法で精度が上がらないかと思い上記の質問を設けた次第です.
もう少し他の方法で改善を試みようと思います.

Connectez-vous pour commenter.

Réponses (0)

Catégories

En savoir plus sur Deep Learning Toolbox dans Help Center et File Exchange

Produits


Version

R2019b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!