How do I sum two gradient objects in the Deep Learning Toolbox?
Hi,
I have gradients1 and gradients2, which have exactly the same structure but different numerical values. How can I sum them? I tried gradients1+gradients2, but I got an error.
Thanks!
My code:
rng(123); % seed
X_ori = [4,163,80; 5,164,75]; % data: 2 observations, 3 features each
X = permute(X_ori,[3,4,2,1]); % reshape to 1x1x3x2 (spatial, spatial, channel, batch)
dlX = dlarray(X, 'SSCB');
Y_ori = [0, 0, 0, 1; 0, 1, 0, 0]; % data labels (one-hot vectors for 4 classes)
myModel = [
    imageInputLayer([1 1 3],'Normalization','none','Name','in')
    fullyConnectedLayer(7,'Name','Layer 1')
    fullyConnectedLayer(4,'Name','Layer 2')];
MyLGraph = layerGraph(myModel);
myDLnet = dlnetwork(MyLGraph);
gradients1 = dlfeval(@modelGradients1, myDLnet, dlX, Y_ori);
gradients2 = dlfeval(@modelGradients2, myDLnet, dlX, Y_ori);
gradients_sum = gradients1+gradients2; % this line throws the error

function [gradients1] = modelGradients1(myModel, modelInput, CorrectLabels)
    CorrectLabels_transpose = transpose(CorrectLabels);
    modelOutput = forward(myModel,modelInput);
    % scaled log-sigmoid loss with weight 31
    loss = -31*sum(sum(CorrectLabels_transpose.*log(sigmoid(modelOutput/100))));
    gradients1 = dlgradient(loss, myModel.Learnables);
end

function [gradients2] = modelGradients2(myModel, modelInput, CorrectLabels)
    CorrectLabels_transpose = transpose(CorrectLabels);
    modelOutput = forward(myModel,modelInput);
    % same loss as modelGradients1, but with weight 42
    loss = -42*sum(sum(CorrectLabels_transpose.*log(sigmoid(modelOutput/100))));
    gradients2 = dlgradient(loss, myModel.Learnables);
end
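For reference, the two losses in the code above differ only in their scale factor. Writing $z_{kn}$ for the network output (modelOutput) for class $k$ and observation $n$, $y_{kn}$ for the transposed one-hot labels, $\sigma$ for the sigmoid, and $\theta$ for the learnable parameters (myDLnet.Learnables):

$$
L_1 = -31 \sum_{k=1}^{4} \sum_{n=1}^{2} y_{kn} \log \sigma\!\left(\frac{z_{kn}}{100}\right),
\qquad
L_2 = -42 \sum_{k=1}^{4} \sum_{n=1}^{2} y_{kn} \log \sigma\!\left(\frac{z_{kn}}{100}\right)
$$

Because the gradient is linear, the parameter-by-parameter sum being asked for satisfies

$$
\nabla_\theta L_1 + \nabla_\theta L_2 = \nabla_\theta (L_1 + L_2) = \frac{73}{31}\,\nabla_\theta L_1 ,
$$

so, up to floating-point rounding, the summed gradients should match the gradient of a single loss with factor $-73$. This can serve as a quick sanity check for whichever summation method is used.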
Answers (1)
Sourav Bairagya
on 10 Dec 2019
0 votes
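In this case, gradients1 and gradients2 are tables (with the variables Layer, Parameter, and Value, mirroring myDLnet.Learnables), and 'gradients1.Value' and 'gradients2.Value' are cell arrays whose elements are the gradient dlarrays of the individual learnable parameters. These elements have different sizes (7x3 and 7x1 for the weights and bias of 'Layer 1', 4x7 and 4x1 for 'Layer 2'), so you can neither add the tables directly with the '+' operator nor convert the cell arrays into a single matrix with 'cell2mat'. Instead, you have to access each element individually and add them.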
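Following the suggestion above of adding the entries one at a time, here is a minimal sketch. It assumes both tables come from the same dlnetwork, so their rows (Layer/Parameter pairs) are in the same order. To stay self-contained, it rebuilds the network from the question and uses a hypothetical helper, scaledGradients, which computes the same loss as modelGradients1/modelGradients2 but takes the scale factor as an argument. Copying gradients1 keeps the Layer and Parameter variables; only the Value cells are replaced by their element-wise sums, either with cellfun or with an explicit loop.

```matlab
% Sketch: sum two gradient tables returned by dlgradient(loss, dlnet.Learnables).
% Assumes Deep Learning Toolbox (R2019b or later) and that both tables come from
% the same dlnetwork, so their rows are in the same order.
rng(123);
X_ori = [4,163,80; 5,164,75];
dlX = dlarray(permute(X_ori,[3,4,2,1]), 'SSCB');
Y_ori = [0,0,0,1; 0,1,0,0];
layers = [
    imageInputLayer([1 1 3],'Normalization','none','Name','in')
    fullyConnectedLayer(7,'Name','Layer 1')
    fullyConnectedLayer(4,'Name','Layer 2')];
myDLnet = dlnetwork(layerGraph(layers));

gradients1 = dlfeval(@scaledGradients, myDLnet, dlX, Y_ori, 31);
gradients2 = dlfeval(@scaledGradients, myDLnet, dlX, Y_ori, 42);

% Option 1: copy one table (keeps Layer/Parameter) and add the Value cells pairwise
gradients_sum = gradients1;
gradients_sum.Value = cellfun(@plus, gradients1.Value, gradients2.Value, ...
    'UniformOutput', false);

% Option 2: explicit loop over the rows, adding one parameter at a time
gradients_loop = gradients1;
for k = 1:height(gradients1)
    gradients_loop.Value{k} = gradients1.Value{k} + gradients2.Value{k};
end

% Hypothetical helper: same loss as modelGradients1/2, with the scale as a parameter
function gradients = scaledGradients(net, X, Y, scale)
    Z = forward(net, X);
    loss = -scale*sum(sum(transpose(Y).*log(sigmoid(Z/100))));
    gradients = dlgradient(loss, net.Learnables);
end
```

The result is again a table with the same layout as myDLnet.Learnables, so it can be passed wherever a gradient table is expected, for example to adamupdate or sgdmupdate. The cellfun version is more compact; the loop makes the per-parameter addition explicit.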