Error in custom training loop: Value to differentiate must be a traced dlarray scalar.
Is it possible to include a black-box function and still use automatic differentiation in MATLAB?
I am trying to do the following.
1) I have 3 input features, the x, y and z locations, computed using a custom function (getcondvects_n_k). There are M such examples, so xyz is a 3-by-M dlarray:
xyz=dlarray(flip(getcondvects_n_k([], 3, val_vectors),2),'BC');
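A detail worth checking in step 1 is the effect of the 'BC' label. When dlarray is given a format, it permutes the labelled dimensions into the order S, C, B, T, U. The sketch below assumes that flip(getcondvects_n_k(...),2) returns an M-by-3 matrix, and random numbers stand in for it. Under that assumption, the 'BC' label produces a 3-by-M array with format 'CB', which is the layout featureInputLayer(3) expects.

```matlab
% Hypothetical stand-in for flip(getcondvects_n_k([], 3, val_vectors), 2):
% an M-by-3 matrix whose columns are the x, y and z coordinates.
M = 5;
coords = rand(M, 3);

% Label dimension 1 as batch (B) and dimension 2 as channel (C).
% dlarray permutes labelled dimensions into S, C, B, T, U order,
% so the stored data becomes 3-by-M with format 'CB'.
xyz = dlarray(coords, 'BC');

disp(dims(xyz))
disp(size(xyz))
```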
2) A NN computes one value per example. The final sigmoid layer keeps that value between 0 and 1:
layers = [
featureInputLayer(3,"Name","elementCenterLocations")
fullyConnectedLayer(20,"Name","fclayer1")
batchNormalizationLayer("Name","batchnorm1")
leakyReluLayer(0.3,"Name","leakyrelu1")
fullyConnectedLayer(1,"Name","fclayer2")
sigmoidLayer("Name","sigmoid")];
lgraph = layerGraph(layers);
dlnet=dlnetwork(lgraph);
3) Forward Pass
r=forward(dlnet,xyz);
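The network contains a batch-normalization layer, so forward can also return the updated layer state. In a training loop, that state should be written back to the network. The sketch below rebuilds the network from step 2 and feeds it random stand-in coordinates in place of the real xyz (an assumption). It shows that r is a 1-by-M array of sigmoid outputs lying between 0 and 1, not exact 0/1 values.

```matlab
layers = [
    featureInputLayer(3,"Name","elementCenterLocations")
    fullyConnectedLayer(20,"Name","fclayer1")
    batchNormalizationLayer("Name","batchnorm1")
    leakyReluLayer(0.3,"Name","leakyrelu1")
    fullyConnectedLayer(1,"Name","fclayer2")
    sigmoidLayer("Name","sigmoid")];
dlnet = dlnetwork(layerGraph(layers));

% Random stand-in for the real 3-by-M coordinates
M = 8;
xyz = dlarray(single(rand(3, M)), 'CB');

% The second output holds the updated batch-normalization statistics
[r, state] = forward(dlnet, xyz);
dlnet.State = state;

% One sigmoid output per example
rVals = extractdata(r);
fprintf('r is %d-by-%d, min %.3f, max %.3f\n', size(r, 1), size(r, 2), min(rVals), max(rVals));
```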
4) Blackbox
The output from the NN is fed to a separate function. It acts like a custom loss function: it computes the loss and the derivative of the loss with respect to r, i.e. dl_dr, which is an nx-by-ny-by-nz array (with nx*ny*nz = M).
R=reshape(double(extractdata(r)),nx, ny,nz);
[loss, dl_dr]=black_box(R, other_inputs);
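MATLAB never sees inside black_box, so the whole update depends on dl_dr being correct. It is worth checking it against central finite differences on a small grid. The sketch uses a hypothetical toy_black_box in place of the real black_box(R, other_inputs). The toy function is a squared-error loss against a made-up target T, chosen only because its derivative, R - T, is known. The same loop works with any function that returns [loss, dl_dr].

```matlab
% Small grid; in the real problem nx*ny*nz must equal M
nx = 4; ny = 3; nz = 2;
R = rand(nx, ny, nz);

% Hypothetical toy black box: loss = 0.5*||R - T||^2, so dloss/dR = R - T
T = double(rand(nx, ny, nz) > 0.5);
toy_black_box = @(R) deal(0.5*sum((R(:) - T(:)).^2), R - T);

[loss, dl_dr] = toy_black_box(R);

% Central finite difference for every element of R
h = 1e-6;
fd = zeros(size(R));
for k = 1:numel(R)
    Rp = R;  Rp(k) = Rp(k) + h;
    Rm = R;  Rm(k) = Rm(k) - h;
    [lp, ~] = toy_black_box(Rp);   % two outputs requested, as deal needs
    [lm, ~] = toy_black_box(Rm);
    fd(k) = (lp - lm) / (2*h);
end

maxErr = max(abs(fd(:) - dl_dr(:)));
fprintf('max |finite difference - dl_dr| = %.2e\n', maxErr);
```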
5) Backward Pass
I want to use dl_dr to update the weights of the NN:
grad = dlgradient(dlarray(dl_dr(:)),dlnet.Learnables,'RetainData',true);
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,loop,learnRate);
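The error comes from how dlgradient is used. It only works inside a function evaluated with dlfeval, and the value it differentiates must be a scalar dlarray that was computed from dlnet.Learnables while tracing was active. dlarray(dl_dr(:)) is a new, non-scalar array with no link to the network, so there is nothing to trace back.

One common workaround (a standard technique, not something stated in the question) treats dl_dr as a constant g and differentiates the scalar surrogate sum(r .* g) instead. Its gradient with respect to the weights is exactly the chain-rule product of dl_dr and dr/dθ. The sketch below makes these assumptions:

- A hypothetical toyBlackBox replaces black_box(R, other_inputs).
- Random coordinates replace xyz.

It runs the forward pass twice: once untraced to feed the black box, and once inside dlfeval to build the surrogate. It then checks the final-layer bias gradient against the hand-derived value.

```matlab
rng(0);
layers = [
    featureInputLayer(3,"Name","elementCenterLocations")
    fullyConnectedLayer(20,"Name","fclayer1")
    batchNormalizationLayer("Name","batchnorm1")
    leakyReluLayer(0.3,"Name","leakyrelu1")
    fullyConnectedLayer(1,"Name","fclayer2")
    sigmoidLayer("Name","sigmoid")];
dlnet = dlnetwork(layerGraph(layers));

nx = 4; ny = 3; nz = 2;
M = nx*ny*nz;
xyz = dlarray(single(rand(3, M)), 'CB');   % stand-in coordinates

% Hypothetical toy black box with a known derivative (replace with black_box)
T = double(rand(nx, ny, nz) > 0.5);
toyBlackBox = @(R) deal(0.5*sum((R(:) - T(:)).^2), R - T);

% 1) Untraced forward pass; the black box only sees plain double arrays
r = forward(dlnet, xyz);
R = reshape(double(extractdata(r)), nx, ny, nz);
[loss, dl_dr] = toyBlackBox(R);

% 2) Upstream gradient laid out like r (1-by-M, 'CB') and held constant
g = dlarray(cast(reshape(dl_dr, 1, []), 'like', extractdata(r)), 'CB');

% 3) Traced surrogate: d/dtheta sum(r .* g) = sum(g .* dr/dtheta)
surrogateGrad = @(net, x, g) dlgradient(sum(forward(net, x) .* g, 'all'), net.Learnables);
grad = dlfeval(surrogateGrad, dlnet, xyz, g);

% Chain-rule check on the last bias b: dr_i/db = r_i*(1 - r_i)
rv = extractdata(r);
gv = extractdata(g);
idx = strcmp(grad.Layer, "fclayer2") & strcmp(grad.Parameter, "Bias");
fprintf('autodiff %.6f vs analytic %.6f\n', ...
    extractdata(grad.Value{idx}), sum(gv .* rv .* (1 - rv)));

% 4) Adam step exactly as in the question
averageGrad = []; averageSqGrad = [];
loop = 1; learnRate = 0.01;
[dlnet, averageGrad, averageSqGrad] = adamupdate(dlnet, grad, ...
    averageGrad, averageSqGrad, loop, learnRate);
```

The black box never receives a dlarray, and nothing inside it has to be differentiable by MATLAB.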
6) The Forward Pass, Blackbox and Backward Pass will be in a custom training loop.
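A custom training loop built this way could look roughly like the sketch below. Each iteration does three things:

- Runs the network and the black box on plain arrays.
- Keeps the returned dl_dr fixed as g.
- Lets dlfeval/dlgradient differentiate the scalar sum(r .* g) with respect to the network weights only.

The input coordinates and toyBlackBox are stand-ins (assumptions) for the real xyz and black_box(R, other_inputs). numIterations and learnRate are arbitrary illustrative values.

```matlab
rng(0);
layers = [
    featureInputLayer(3,"Name","elementCenterLocations")
    fullyConnectedLayer(20,"Name","fclayer1")
    batchNormalizationLayer("Name","batchnorm1")
    leakyReluLayer(0.3,"Name","leakyrelu1")
    fullyConnectedLayer(1,"Name","fclayer2")
    sigmoidLayer("Name","sigmoid")];
dlnet = dlnetwork(layerGraph(layers));

nx = 4; ny = 3; nz = 2;
M = nx*ny*nz;
xyz = dlarray(single(rand(3, M)), 'CB');   % stand-in coordinates

% Hypothetical toy black box (replace with black_box(R, other_inputs))
T = double(rand(nx, ny, nz) > 0.5);
toyBlackBox = @(R) deal(0.5*sum((R(:) - T(:)).^2), R - T);

% Gradient of the surrogate sum(r .* g) with g held constant
surrogateGrad = @(net, x, g) dlgradient(sum(forward(net, x) .* g, 'all'), net.Learnables);

numIterations = 200;
learnRate = 0.01;
averageGrad = [];
averageSqGrad = [];
lossHistory = zeros(numIterations, 1);

for loop = 1:numIterations
    % Forward pass (untraced); keep the batch-normalization state
    [r, state] = forward(dlnet, xyz);
    dlnet.State = state;

    % Black box on plain arrays
    R = reshape(double(extractdata(r)), nx, ny, nz);
    [loss, dl_dr] = toyBlackBox(R);
    lossHistory(loop) = loss;

    % Upstream gradient from the black box, shaped like r
    g = dlarray(cast(reshape(dl_dr, 1, []), 'like', extractdata(r)), 'CB');

    % Backward pass through the network only, then the Adam update
    grad = dlfeval(surrogateGrad, dlnet, xyz, g);
    [dlnet, averageGrad, averageSqGrad] = adamupdate(dlnet, grad, ...
        averageGrad, averageSqGrad, loop, learnRate);
end

fprintf('loss: first %.4f, last %.4f\n', lossHistory(1), lossHistory(end));
```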
I get this error when dlgradient is called. Can you please suggest changes, if any? There are no known outputs Y, and the black box has many steps that involve matrix inversion. The inputs to the black box cannot be dlarrays.
However, the equation relating the loss and r is straightforward, so its derivative is also straightforward.
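A known derivative is enough for automatic differentiation through the network, by the chain rule. Let θ be the network weights, r_i(θ) the network output for example i, and g_i = ∂L/∂r_i the values returned in dl_dr, held fixed (not differentiated). Then:

```latex
\frac{\partial L}{\partial \theta}
  = \sum_{i=1}^{M} \frac{\partial L}{\partial r_i}\,\frac{\partial r_i}{\partial \theta}
  = \sum_{i=1}^{M} g_i\,\frac{\partial r_i}{\partial \theta}
  = \frac{\partial}{\partial \theta}\left( \sum_{i=1}^{M} g_i\, r_i(\theta) \right)\Bigg|_{g\ \text{fixed}}
```

The black box therefore only needs to return dl_dr. The scalar \(\sum_i g_i r_i(\theta)\) is something dlgradient can differentiate when it is computed inside dlfeval.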