dlgradient returning zeros when the loss depends on another, already trained network

Marcos on 31 Jan 2024
Commented: Marcos on 12 Feb 2024
Hello,
My loss depends on another, already trained shallow network (net2). I then call dlgradient on this loss to compute the gradients, but it only returns zeros.
Here is the function that computes the loss and gradients:
function [loss,gradients] = modelLoss(dlnet,X,net2)
% Forward data through network.
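Y = forward(dlnet,X);
% Data through trained network2.
% (extractdata returns plain numeric data, so derivative tracing stops here.)
X_2 = [extractdata(X(1:2,:));extractdata(Y)];
X_pred = net2(X_2);
% Convert to dlarray.
% (This creates a new dlarray with no link to dlnet's learnables.)
dlX_pred = dlarray(X_pred,'CB');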
% Calculate loss.
loss = mean(mean((dlX_pred - X(3:end,:)).^2));
% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);
end
And here is how I call it with dlfeval:
[loss,gradients] = dlfeval(@modelLoss,dlnet,dLXMiniBatch,net2);
Any idea what's missing?
Thanks!

Accepted Answer

Avadhoot on 12 Feb 2024
Hi Marcos,
I see that you are computing gradients of a custom loss function for your model. As mentioned in the example, you have correctly included all the calculations involved in the loss computation, including the pretrained network, inside the loss function. Still, a few issues could cause the gradient function to return zeros. Below are the probable causes:
  1. Make sure that you are not updating the pretrained network's weights while computing the loss; only the parameters of dlnet should change.
  2. The gradients cannot propagate if the pretrained network includes a non-differentiable step.
  3. dlnet.Learnables should contain all the parameters that you want to update. Check that all the relevant parameters are included.
  4. X and Y need to be dlarray objects for dlgradient to work. If they are not, convert them to dlarray objects before passing them to the loss function.
Since you are using a shallow network, vanishing gradients are unlikely to be the cause. Please check each of the factors above to find the source of the problem.
I hope this helps.
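The following sketch illustrates why the gradients in the question can come out as zeros. The function names traceDemo and lossAndGradients are hypothetical. The loss depends on x through ordinary dlarray operations, and on w only through extractdata followed by dlarray, which is the same pattern as X_2 and dlX_pred in the question. Because the loss is still traced through x, dlgradient runs without error, but the path from w to the loss has been cut, so its gradient is expected to be zero, mirroring the symptom reported for dlnet.Learnables.

```matlab
function [gx,gw] = traceDemo(x,w)
% Hypothetical demo: x and w are scalar dlarray inputs.
% dlfeval traces both inputs so that dlgradient can be called inside.
[gx,gw] = dlfeval(@lossAndGradients,x,w);
end

function [gx,gw] = lossAndGradients(x,w)
% x enters the loss through dlarray operations, so its path stays traced.
tracedPart = x.^2;
% w is converted to numeric with extractdata and wrapped in a new dlarray,
% which drops the link back to w (as with X_pred in the question).
brokenPart = dlarray(extractdata(w).^2);
loss = tracedPart + brokenPart;
% The loss is traced through x, so dlgradient runs, but w no longer
% influences the traced computation.
[gx,gw] = dlgradient(loss,x,w);
end
```

Usage: [gx,gw] = traceDemo(dlarray(3),dlarray(2)) gives the derivative 2*x for x, while the gradient for w carries no information from the loss.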
  1 comment
Marcos on 12 Feb 2024
Hi Avadhoot,
Thanks for your response.
I managed to solve the problem by implementing net2 as a dlnetwork as well, instead of using feedforwardnet. Y was a dlarray, and I was converting X_pred back to a dlarray inside the loss function; I guess that was the problem.
Anyway, having two dlnetwork objects (a pretrained one and the one being trained) isn't a problem, so I'm using that and it seems to be working :)
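Below is a minimal sketch of a loss function along the lines described above. It is not the code posted in this thread, and the name modelLossTraced is hypothetical. It assumes net2 has been rebuilt and retrained as a dlnetwork whose input size is 2 plus the number of outputs of dlnet, and whose output size matches X(3:end,:). Every step from dlnet's learnables to the loss stays a dlarray operation, so nothing breaks the trace.

```matlab
function [loss,gradients] = modelLossTraced(dlnet,X,net2)
% dlnet: dlnetwork being trained.
% net2:  pretrained dlnetwork (assumed), kept fixed.
% X:     formatted 'CB' dlarray; rows 3:end hold the targets, as in the question.

% Forward data through the network being trained.
Y = forward(dlnet,X);

% Build net2's input by concatenating dlarrays directly. There is no
% extractdata call, so the dependence on dlnet's learnables stays traced.
X_2 = [X(1:2,:); Y];

% predict runs net2 in inference mode; its output is still a traced dlarray.
X_pred = predict(net2,X_2);

% Mean squared error, with the same reduction as in the question.
loss = mean(mean((X_pred - X(3:end,:)).^2));

% Gradients with respect to dlnet's learnables only. net2's parameters are
% not requested, so they are never updated.
gradients = dlgradient(loss,dlnet.Learnables);
end
```

It is called the same way as before, for example [loss,gradients] = dlfeval(@modelLossTraced,dlnet,dLXMiniBatch,net2). Using predict rather than forward for net2 is a design choice: both support automatic differentiation, but predict uses inference behavior for layers such as dropout.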
Thanks again!
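To make point 1 of the answer concrete, here is a sketch of a training loop in which only dlnet is updated. The name trainWithFrozenNet2 is hypothetical. For simplicity, it reuses a single batch X (for example dLXMiniBatch), and lossFcn is assumed to follow the same signature as modelLoss in the question. Because net2 is only read inside the loss function and is never passed to adamupdate, its pretrained weights stay fixed.

```matlab
function dlnet = trainWithFrozenNet2(dlnet,net2,X,lossFcn,numIterations,learnRate)
% Hypothetical sketch: Adam updates of dlnet only, net2 kept fixed.
% lossFcn is assumed to have the signature
%   [loss,gradients] = lossFcn(dlnet,X,net2)
% with gradients taken with respect to dlnet.Learnables.

% Adam moment estimates start empty; adamupdate initializes them.
averageGrad = [];
averageSqGrad = [];

for iteration = 1:numIterations
    % dlfeval enables tracing so dlgradient can run inside lossFcn.
    [loss,gradients] = dlfeval(lossFcn,dlnet,X,net2);

    % Only dlnet is passed to adamupdate, so net2 is never modified.
    [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients, ...
        averageGrad,averageSqGrad,iteration,learnRate);

    fprintf('Iteration %d, loss %.4g\n',iteration,extractdata(loss));
end
end
```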

Version

R2022b