Initial State Dynamical System LSTM Network

Michael Hesse on 18 Nov 2020
Commented: Michael Hesse on 19 Nov 2020
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';
X = x(:, 1:end-1);
Y = x(:, 2:end);
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures);
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
net = resetState(net);
xpred = x0;
% closed-loop rollout: each predicted state is fed back as the next input
for i = 1 : length(t)-1
    [net, xpred(:, i+1)] = predictAndUpdateState(net, xpred(:, i));
end
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
This is example code in which I want to predict the trajectory of a pendulum with an LSTM neural network. How can I provide the initial state x0 to the network? If you look at the figure, the second state jumps directly from x0 to the state [0, 0]'. Why does this happen?
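One pattern that may be worth trying here is to prime the LSTM state with a few measured samples before closing the loop, so that the recurrent state is no longer at its reset value when predictions start being fed back. The sketch below assumes net, x, and t already exist from the script above; nWarm is a hypothetical warm-up length, not a tuned value, and there is no guarantee that this removes the jump.

```matlab
% Sketch only: assumes net, x, and t already exist from the script above.
nWarm = 20;                          % hypothetical warm-up length (assumption)

% Start from a clean recurrent state, then feed measured samples 1..nWarm-1.
% Their outputs are discarded; only the updated network state is kept.
net = resetState(net);
[net, ~] = predictAndUpdateState(net, x(:, 1:nWarm-1));

% Closed loop from sample nWarm onward: each prediction is the next input.
xpred = zeros(size(x));
xpred(:, 1:nWarm) = x(:, 1:nWarm);
for i = nWarm : length(t)-1
    [net, xpred(:, i+1)] = predictAndUpdateState(net, xpred(:, i));
end
```

Plotting xpred against x as in the script shows whether the warm-up changes the behaviour at the start of the rollout.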
1 comment
Michael Hesse on 19 Nov 2020
Here is a possible workaround. Instead of learning the next state, one can learn the difference between the current state and the next state.
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';
X = x(:, 1:end-1);
Y = x(:, 2:end) - x(:, 1:end-1);
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures, 'Normalization', 'zscore');
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
xpred = x0;
% closed-loop rollout: add each predicted increment to the current state
for i = 1 : length(t)-1
    [net, dxpred] = predictAndUpdateState(net, xpred(:, i));
    xpred(:, i+1) = xpred(:, i) + dxpred;
end
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
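To make the difference formulation concrete: with targets Y(:, k) = x(:, k+1) - x(:, k), accumulating the increments from x0 reproduces the trajectory, which is what the prediction loop above does step by step with the network's estimates. The sketch below checks this on the simulated data alone, without a network. The names dX, xrec, maxErr, incrScale, and stateScale are illustrative, and the implicit expansion in x0 + cumsum(...) assumes R2016b or later.

```matlab
% Simulate the same system as in the script above.
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :)) - x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';

% Per-step increments, identical to Y in the workaround.
dX = x(:, 2:end) - x(:, 1:end-1);

% Accumulate the increments from x0 to rebuild the trajectory.
xrec = [x0, x0 + cumsum(dX, 2)];

% Largest reconstruction error over all states and time steps.
maxErr = max(abs(xrec(:) - x(:)));

% Magnitude of the regression targets and of the raw states, per component.
incrScale = max(abs(dX), [], 2);
stateScale = max(abs(x), [], 2);
```

Comparing incrScale with stateScale shows how the regression targets relate in magnitude to the raw states for this step size.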

Release: R2020b
