Custom deep learning network - gradient function using dlfeval
I want to create a custom deep learning training function, the output of which is an array Y. I have two inputs, the arrays X1 and X2. I want to find the gradient of Y with respect to X1 and X2.
This is my network:
layers1 = [
sequenceInputLayer(sizeInput,"Name","XTrain1")
fullyConnectedLayer(numHiddenDimension,"Name","fc_1")
softplusLayer('Name','s_1')];
layers2 = [
sequenceInputLayer(sizeInput,"Name","XTrain2")
fullyConnectedLayer(numHiddenDimension,"Name","fc_2")
softplusLayer('Name','s_2')];
lgraph = layerGraph(layers1);
lgraph = addLayers(lgraph,layers2); % connect layers -> 2 in, 1 out
add = additionLayer(2,'Name','add');
lgraph = addLayers(lgraph,add);
lgraph = connectLayers(lgraph,'s_1','add/in1');
lgraph = connectLayers(lgraph,'s_2','add/in2');
fc = fullyConnectedLayer(sizeInput,"Name","fc_3");
lgraph = addLayers(lgraph,fc);
lgraph = connectLayers(lgraph,'add','fc_3');
dlnet = dlnetwork(lgraph);
My last layer, fc_3, should become my output. Then, every iteration, I do:
dlX1 = dlarray(X1,'CTB');
dlX2 = dlarray(X2,'CTB');% to differentiate: dlarray/dlgradient
for i = 1:sizeInput
[gradx1(i), gradx2(i), dlY] = dlfeval(@modelGradientsX,dlnet,dlX1(i),dlX2(i)); % here is where I get my error
end
and I call my function modelGradientsX, which is supposed to get the derivative of my output with respect to my inputs:
function [gradx1, gradx2, dlY] = modelGradientsX(dlnet,dlX1,dlX2)
dlY = forward(dlnet,dlX1,dlX2);
[gradx1, gradx2] = dlgradient(dlY,dlX1,dlX2);
end
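A side note on the function above: dlgradient differentiates a scalar, so if forward returns an array, dlY has to be reduced to a scalar (or a single element picked) before the call. Below is a minimal sketch of that, assuming the network from the question, hypothetical sizes and random data, and a hypothetical helper name modelGradientsSum. Summing the output is only one possible reduction and may not be the quantity you actually want. Note that softplusLayer needs Reinforcement Learning Toolbox.

```matlab
% Hypothetical sizes and random data, chosen only so the sketch runs on its own.
sizeInput = 3;
numHiddenDimension = 4;
numTimeSteps = 5;
X1 = rand(sizeInput,numTimeSteps);
X2 = rand(sizeInput,numTimeSteps);

% Same two-branch network as in the question.
layers1 = [
    sequenceInputLayer(sizeInput,"Name","XTrain1")
    fullyConnectedLayer(numHiddenDimension,"Name","fc_1")
    softplusLayer("Name","s_1")];
layers2 = [
    sequenceInputLayer(sizeInput,"Name","XTrain2")
    fullyConnectedLayer(numHiddenDimension,"Name","fc_2")
    softplusLayer("Name","s_2")];
lgraph = layerGraph(layers1);
lgraph = addLayers(lgraph,layers2);
lgraph = addLayers(lgraph,additionLayer(2,"Name","add"));
lgraph = connectLayers(lgraph,"s_1","add/in1");
lgraph = connectLayers(lgraph,"s_2","add/in2");
lgraph = addLayers(lgraph,fullyConnectedLayer(sizeInput,"Name","fc_3"));
lgraph = connectLayers(lgraph,"add","fc_3");
dlnet = dlnetwork(lgraph);

% Pass the whole formatted arrays, so the format labels stay attached.
dlX1 = dlarray(X1,'CTB');
dlX2 = dlarray(X2,'CTB');
[gradx1,gradx2,dlY] = dlfeval(@modelGradientsSum,dlnet,dlX1,dlX2);

% Each gradient has the same size as the input it was taken with respect to.
disp(size(gradx1))
disp(size(gradx2))

function [gradx1,gradx2,dlY] = modelGradientsSum(dlnet,dlX1,dlX2)
% Forward pass; the inputs follow the order of dlnet.InputNames.
dlY = forward(dlnet,dlX1,dlX2);
% Reduce the output to a scalar before differentiating (assumption: the sum
% of all outputs is an acceptable target).
total = sum(dlY,'all');
[gradx1,gradx2] = dlgradient(total,dlX1,dlX2);
end
```

If a separate gradient per output element is needed, the same pattern should work with one chosen element of dlY in place of the sum, at the cost of one dlfeval call per element.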
And the error I get is: "Input data must be formatted dlarray objects". I have seen similar approaches in other examples (like this one: https://www.mathworks.com/matlabcentral/fileexchange/74760-image-classification-using-cnn-with-multi-input-cnn), so I don't understand: why is the data I pass in not the correct type?
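One way to narrow this down is to print the format labels of exactly what gets passed to forward. The sketch below uses dims, which returns the labels of a formatted dlarray and an empty value for an unformatted one. The data is hypothetical. My reading of the dlarray indexing rules is that an index with fewer subscripts than the array has dimensions comes back unformatted; treat that as an assumption to confirm on your release, which is why the code prints the labels rather than asserting them.

```matlab
% Hypothetical data: 3 channels, 5 time steps, a single observation.
X1 = rand(3,5);
dlX1 = dlarray(X1,'CTB');

% dlarray stores the dimensions in the order S, C, B, T, U, so the labels
% come back reordered and the array is 3-by-1-by-5.
disp(dims(dlX1))
disp(size(dlX1))

% Linear indexing, as in dlX1(i) from the loop in the question.
piece = dlX1(1);
fprintf('labels after dlX1(1): "%s"\n',dims(piece));

% One subscript per dimension: all channels of time step 2.
slice = dlX1(:,:,2);
fprintf('labels after dlX1(:,:,2): "%s"\n',dims(slice));
```

If the first fprintf shows empty labels, then the loop is handing unformatted scalars to forward, which would match the error message. Indexing with one subscript per dimension, or passing the whole formatted array, are two ways to try keeping the labels attached.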
Accepted Answer
More Answers (1)
Iris Soa
on 27 Jul 2020