機械学習の入力エラーについて
2 vues (au cours des 30 derniers jours)
Afficher commentaires plus anciens
LSTMの学習方法について質問です.
最下部に示したコードを実行したとき,「予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。」が表示されうまく学習できません.
入力データは1タイムステップに t-2, t-1, t のデータが含まれており,それに対応する出力データは t+1 のデータとなっています.
ここで学習に用いるデータを
net = trainNetwork(XTrain_C, YTrain_C, layers, options);
のように,Cのみを用いるようにすると上手く実行できるのですが,元コードのようにA,C,Dの3つの時系列データを学習させたモデルを作成しようとするとエラー文が表示されてしまいます.
下記のコードをどう修正すれば実行可能になりますでしょうか?
clear all
close all
%% Make dataset
A = zeros(1,100);
B = zeros(1,100);
C = zeros(1,100);
D = zeros(1,100);
% A
for i = 1:100
if i <= 40
A(:,i) = i / 40;
elseif (41 <= i) && (i <= 45)
A(:,i) = 1 - ((i - 40) / 5);
elseif 46 <= i
A(:,i) = 0;
end
end
% B
for i = 1:100
if i <= 60
B(:,i) = i / 60;
elseif (61 <= i) && (i <= 65)
B(:,i) = 1 - ((i - 60) / 5);
elseif 66 <= i
B(:,i) = 0;
end
end
% C
for i = 1:100
if i <= 80
C(:,i) = i / 80;
elseif (81 <= i) && (i <= 85)
C(:,i) = 1 - ((i - 80) / 5);
elseif 86 <= i
C(:,i) = 0;
end
end
% D
for i = 1:100
if i <= 40
D(:,i) = i / 20;
elseif (21 <= i) && (i <= 25)
D(:,i) = 1 - ((i - 20) / 5);
elseif 26 <= i
D(:,i) = 0;
end
end
%% Plot
plot(1:100, A(1,:),'LineWidth',2);hold on
plot(1:100, B(1,:),'LineWidth',2);hold on
plot(1:100, C(1,:),'LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('A','B','C','Location','northwest')
grid on
%% Preparing for ML
% A
for i = 1:97
XTrain_A{1,i} = A(:,i:i+2).';
YTrain_A{1,i} = A(:,i+3);
end
% C
for i = 1:97
XTrain_C{1,i} = C(:,i:i+2).';
YTrain_C{1,i} = C(:,i+3);
end
% D
for i = 1:97
XTrain_D{1,i} = D(:,i:i+2).';
YTrain_D{1,i} = D(:,i+3);
end
% Input
XTrain{1,1} = XTrain_D;
XTrain{2,1} = XTrain_A;
XTrain{3,1} = XTrain_C;
YTrain{1,1} = YTrain_D;
YTrain{2,1} = YTrain_A;
YTrain{3,1} = YTrain_C;
%% TrainNetwork
numFeatures = 3;
numResponses = 1;
numHiddenUnits = 300;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(20)
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('adam', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.0001, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',50, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
%% Test
Result = zeros(1,100);
Result(:,1:3) = B(1,1:3);
for i = 1:97
[net,Result(1,i+1)] = predictAndUpdateState(net, Result(:,i:i+2).');
end
%% Plot result
plot(1:100, B(1,:),'k','LineWidth',2);hold on
plot(1:100, Result(1,:),'r','LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('B','Predection','Location','northwest')
grid on
7 commentaires
Réponses (0)
Voir également
Catégories
En savoir plus sur Deep Learning Toolbox dans Help Center et File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!

