Setting input formats for a nested Neural ODE
Hi all,
I am constructing an NN that nests a Neural ODE. The NN takes two inputs: (i) the initial values of the internal states (InitialValue), which feed the Neural ODE, and (ii) the sequences (input_2), which enter the NN after the Neural ODE. The outputs of the Neural ODE and input_2 must be summed.
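Conceptually, for one observation the network integrates the learned dynamics from InitialValue(i) over tspan and then adds input_2{i} element by element. The sketch below shows that data flow in plain MATLAB with ode45. The random weights W1, b1, W2, b2 are hypothetical stand-ins for the trained OdeNet (fullyConnectedLayer(5), tanhLayer, fullyConnectedLayer(1)). This is not how neuralODELayer works internally; it only illustrates why both branches must produce the same number of time steps.

```matlab
% Conceptual forward pass for ONE observation (illustration only).
rng(0);
data_1 = randn(1,1000);
data_2 = randn(1,1000);
tspan  = 1:1:50;

i  = 1;                          % observation index
x0 = data_1(i);                  % same value as InitialValue(i)
u  = data_2(i + tspan)';         % same values as input_2{i}, as a 50-by-1 column

% Hypothetical weights mirroring FC(5) -> tanh -> FC(1)
W1 = randn(5,1); b1 = randn(5,1);
W2 = randn(1,5); b2 = randn(1,1);
f  = @(t, x) W2*tanh(W1*x + b1) + b2;   % dx/dt = OdeNet(x)

% With a vector tspan, ode45 returns the state at every time in tspan
% (including tspan(1)), so x is 50-by-1 here.
[t, x] = ode45(f, tspan, x0);

% The addition layer sums the ODE trajectory and the second input
y = x + u;
```

Whether neuralODELayer also returns the state at tspan(1), or only at the later times, is worth confirming in its documentation for your release; if it omits the initial time, the ODE branch has one fewer time step than input_2 and the addition cannot match.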
I tried creating the NN inputs both as cell arrays and as dlarray objects. For the dlarray version I also specified the dimension labels 'CBT' and/or 'CB' according to the structure of each data series; nevertheless, the problem persists.
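For reference, a sketch of how the two inputs can be packed into formatted dlarray objects for the whole data set, assuming one channel per input: the initial states become a 1-by-950 "CB" array and the sequences a 1-by-950-by-50 "CBT" array. The variable names X1, X2, numObs, and numT are introduced here for illustration only.

```matlab
% Same data as in the question
data_1 = randn(1,1000);
data_2 = randn(1,1000);
tspan = 1:1:50;
InitialValue = data_1(:,1:end-length(tspan))';   % 950-by-1
indices = 1:length(InitialValue);
input_2 = arrayfun(@(i) data_2(:, i + tspan), indices, 'UniformOutput', false)';  % 950-by-1 cell of 1-by-50

numObs = numel(InitialValue);
numT   = numel(tspan);

% Initial states: 1 channel x numObs observations -> "CB"
X1 = dlarray(InitialValue', "CB");

% Sequences: cell2mat gives numObs-by-numT; reshape to 1-by-numObs-by-numT -> "CBT"
X2 = dlarray(reshape(cell2mat(input_2), 1, numObs, numT), "CBT");
```

Arrays like these can be handy for checking, on an initialized network, that the addition layer receives inputs with matching channel and time dimensions before calling trainnet.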
The error I got is the following:
```matlab
%% Generate data
data_1 = randn(1,1000);
data_2 = randn(1,1000);
tspan = 1:1:50;
InitialValue = data_1(:,1:end-length(tspan))';
indices = 1:length(InitialValue);
targets = arrayfun(@(i) data_1(:, i + tspan), indices, 'UniformOutput', false)';
input_2 = arrayfun(@(i) data_2(:, i + tspan), indices, 'UniformOutput', false)';

%% Create neural network
% Neural ODE layers
OdeLayer = [fullyConnectedLayer(5)
    tanhLayer
    fullyConnectedLayer(1)];
OdeNet = dlnetwork(OdeLayer,Initialize=false);

% Main layers
net = dlnetwork;
Layers = [featureInputLayer(1,'Name','Input 1')
    neuralODELayer(OdeNet,tspan,"Name",'OdeLayer','GradientMode','adjoint')];

% Add an extra input to be summed with the Neural ODE output
net = addLayers(net, Layers);
net = addLayers(net, sequenceInputLayer(1,'Name','Input 2'));
net = addLayers(net, additionLayer(2,'Name','adition'));

% Connect layers
net = connectLayers(net,'Input 2','adition/in2');
net = connectLayers(net,'OdeLayer','adition/in1');

%% Train network
% Gather inputs and targets
input_1_ds = arrayDatastore(InitialValue,"OutputType","same");
input_2_ds = arrayDatastore(input_2,"OutputType","same");
target_ds = arrayDatastore(targets,"OutputType","same");
cds = combine(input_1_ds, input_2_ds, target_ds);
opt = trainingOptions("adam");

% Training
net = trainnet(cds,net,"l2loss",opt);
```
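One thing worth checking is the orientation of the sequence cells: each element of input_2 and targets is a 1-by-50 row. A sketch of re-orienting them, assuming trainnet reads each vector sequence in a cell array as a numTimeSteps-by-numChannels matrix (please verify this against the trainnet documentation for your release). The helper toTimeByChannel and the *_tc variable names are hypothetical.

```matlab
% Same data as in the question
data_1 = randn(1,1000);
data_2 = randn(1,1000);
tspan = 1:1:50;
InitialValue = data_1(:,1:end-length(tspan))';
indices = 1:length(InitialValue);
targets = arrayfun(@(i) data_1(:, i + tspan), indices, 'UniformOutput', false)';
input_2 = arrayfun(@(i) data_2(:, i + tspan), indices, 'UniformOutput', false)';

% ASSUMPTION: sequences are expected as time-by-channel,
% so transpose each 1-by-50 row into a 50-by-1 column.
toTimeByChannel = @(s) s.';
input_2_tc = cellfun(toTimeByChannel, input_2, 'UniformOutput', false);
targets_tc = cellfun(toTimeByChannel, targets, 'UniformOutput', false);

% Rebuild the datastores; the combined columns follow the network
% input order (net.InputNames) and then the targets.
input_1_ds = arrayDatastore(InitialValue, "OutputType", "same");
input_2_ds = arrayDatastore(input_2_tc, "OutputType", "same");
target_ds  = arrayDatastore(targets_tc, "OutputType", "same");
cds = combine(input_1_ds, input_2_ds, target_ds);
```

Depending on the release, trainingOptions may also accept InputDataFormats and TargetDataFormats for use with trainnet, which is another way to state the layout explicitly instead of transposing the data.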
Thanks in advance for your feedback and comments.
Accepted Answer