Custom deep learning training loop: gradient computation using dlgradient
I'm trying to train a CNN with semi-supervised learning, but I can't get the automatic gradient to evaluate properly. When I call dlgradient (with loss and net.Learnables as arguments), it invokes other functions internally, and things go wrong when it reaches backwardTape (the method that, through further nested functions, computes the gradient). backwardTape appears to be skipped: it does return the output grad, but if I try to step into it with the debugger I can't, and execution jumps to the next line of code instead. The line is:
grad = backwardTape(tm,{y},{initialAdjoint},x,retainData,false);
in backwardPass.m of the Deep Learning Toolbox. The output grad is just a vector of empty arrays.
P.S. The dlnetwork I have created is based on AlexNet, using transfer learning.
The relevant part of the code is:
loss = labeledLoss + unlabeledLoss; % these two statements are inside a training loop
gradients = dlfeval(@computeModelGradients,net,loss);
function gradients = computeModelGradients(network,loss)
gradients = dlgradient(loss,network.Learnables);
end
% where:
% studentNet is a 1x1 dlnetwork of 24 layers (of which 22 are from AlexNet
% and the last 2 are a fully connected layer and a softmax layer)
% loss is a 1x1 dlarray (which contains a double)
Answers (1)
Mohamed Marei
on 14 May 2021
I think I ran into a similar problem when attempting to train a ResNet-18-based model for transfer learning, too. I had to hard-code my evaluation and update step which was by no means straightforward.
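In your case, you might want to compute the loss (and the forward pass that produces it) inside the call to dlfeval. dlgradient can only differentiate operations that were traced, that is, operations that ran inside the function evaluated by dlfeval, so a loss computed outside it has no recorded dependence on the learnables.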
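function [loss, gradients] = computeModelGradients(network, X_labelled, tgts_labelled, X_unlabelled)
% The forward passes run here, inside the function called by dlfeval, so they are traced.
pred_labelled = forward(network, X_labelled);
pred_unlabelled = forward(network, X_unlabelled);
labelled_loss = crossentropy(pred_labelled, tgts_labelled); % your loss definition here
unlabelled_loss = myfunction(pred_unlabelled); % your loss function for the unlabeled predictions
loss = labelled_loss + unlabelled_loss;
gradients = dlgradient(loss, network.Learnables); % differentiate with respect to the learnables table
end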
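As a rough, self-contained illustration of the same pattern, here is a minimal sketch that uses a tiny toy network in place of the AlexNet-based model. The layer sizes, the random data, the function name modelLoss and the choice of sgdmupdate for the update step are all assumptions made for the example, not part of the original question. It also assumes a release in which dlnetwork accepts a layer array directly.

```matlab
% Toy stand-in for the real model (assumption: 4 input features, 3 classes).
layers = [featureInputLayer(4) fullyConnectedLayer(3) softmaxLayer];
net = dlnetwork(layers);

% Random labelled batch: 8 observations, format 'CB' (channel x batch).
X = dlarray(rand(4, 8, 'single'), 'CB');
T = zeros(3, 8, 'single');
T(1, :) = 1;                       % one-hot targets, every observation in class 1

vel = [];                          % SGDM velocity state, empty on the first step
for iteration = 1:5
    % Forward pass and loss both run inside modelLoss, so dlfeval traces them.
    [loss, gradients] = dlfeval(@modelLoss, net, X, T);
    [net, vel] = sgdmupdate(net, gradients, vel);
end

function [loss, gradients] = modelLoss(net, X, T)
    Y = forward(net, X);                            % traced forward pass
    loss = crossentropy(Y, T);                      % traced scalar loss
    gradients = dlgradient(loss, net.Learnables);   % table: Layer, Parameter, Value
end
```

In this sketch gradients comes back as a table with one row per learnable parameter, and each Value entry should be a non-empty dlarray. If the loss were instead computed before the call to dlfeval and passed in, there would be nothing traced for dlgradient to differentiate.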