dlgradient of a subset of variables

Matt J on 9 Feb 2024
Edited: Matt J on 14 Feb 2024
I would like to apply dlgradient to a parametrized function so that in some cases I get the complete gradient of the function, and in others only the derivative with respect to the i-th variable. My attempt below fails: the derivative with respect to x(2) comes back as 0, although it should be x(2) = 2. Is there a way that works and is also efficient in the more general case when numel(x0) is large?
x0 = dlarray( [1; 2] );
i=[];
y_grad=dlfeval(@(x)oneDeriv(x,i), x0) %works
% y_grad =
%   2×1 dlarray
%      1
%      2
i=2;
dy_dxi=dlfeval(@(x)oneDeriv(x,i), x0) %fails
% dy_dxi =
%   1×1 dlarray
%      0
function out = oneDeriv(x,i)
    y=sum(x.^2)/2;
    if isempty(i) %take complete gradient
        out=dlgradient(y,x);
    else %take gradient only w.r.t. x(i)
        out=dlgradient(y,x(i));
    end
end

Accepted Answer

Matt J on 13 Feb 2024
Edited: Matt J on 14 Feb 2024
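This seems to work: pass only the subset to dlfeval as the traced input, and splice it back into the full vector before evaluating the function.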
X0 = dlarray( [1; 2; 3; 4]*10 );
subset=[2,3];
gradTotal = getGradient(X0) %total gradient
% gradTotal =
%   4×1 dlarray
%     10
%     20
%     30
%     40
gradSubset = getGradient(X0,subset) %gradient on subset of x
% gradSubset =
%   2×1 dlarray
%     20
%     30
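function grad = getGradient(Xall,subset)
    if nargin<2, subset=':'; end   % default: gradient w.r.t. all of Xall
    % Only Xall(subset) is the dlfeval input, so only those entries are traced
    grad = dlfeval( @(xsub)theFunction(Xall,xsub,subset), Xall(subset));
end
function grad = theFunction(Xall,xsub,subset)
    Xall(subset)=xsub;   % splice the traced entries back into the full vector
    y=sum(Xall.^2)/2;
    grad=dlgradient(y,xsub);
end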
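The same pattern can be made independent of the objective by passing it in as a function handle instead of hard-coding it in theFunction. This is a minimal sketch under that assumption: subsetGradient, gradOnSubset, and the cubic objective sum(X.^3) are hypothetical names chosen for illustration, and the objective is assumed to return a real scalar dlarray, as dlgradient requires.

```matlab
X0 = dlarray([1; 2; 3; 4]*10);
fun = @(X) sum(X.^3);                  % hypothetical objective; its gradient is 3*X.^2

gAll = subsetGradient(fun, X0)         % 3*X0.^2: 300, 1200, 2700, 4800
g    = subsetGradient(fun, X0, [2 3])  % 3*X0(2:3).^2: 1200, 2700

function grad = subsetGradient(fun, Xall, subset)
    % Gradient of the scalar function fun(Xall) w.r.t. Xall(subset) only
    if nargin < 3, subset = ':'; end   % default: all entries
    grad = dlfeval(@(xsub) gradOnSubset(fun, Xall, xsub, subset), Xall(subset));
end

function grad = gradOnSubset(fun, Xall, xsub, subset)
    Xall(subset) = xsub;   % traced entries replace the untraced ones
    y = fun(Xall);         % y now depends on xsub through Xall
    grad = dlgradient(y, xsub);
end
```

As in the accepted answer, only the entries passed to dlfeval are traced, so the returned gradient has numel(subset) elements rather than numel(Xall).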

More Answers (1)

Ben on 13 Feb 2024
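This is a subtle part of the dlarray autodiff system. The line dlgradient(y,x(i)) returns 0 because the indexing x -> x(i) is itself a traced operation that creates a new dlarray after y has already been computed. The trace only records that y depends on x, not on this new array x(i), so the derivative with respect to it is 0.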
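The effect can be reproduced in isolation. In this minimal sketch (indexAfter and indexBefore are hypothetical helper names), the only difference between the two functions is whether x(2) is extracted before or after y is computed; only in the second case does y depend on the extracted array, so only there is the derivative nonzero.

```matlab
x0 = dlarray([1; 2]);
gAfter  = dlfeval(@indexAfter, x0)    % 0, as in the question
gBefore = dlfeval(@indexBefore, x0)   % 2

function g = indexAfter(x)
    y = sum(x.^2)/2;           % y is computed from x
    g = dlgradient(y, x(2));   % x(2) is a new dlarray that y does not depend on
end

function g = indexBefore(x)
    x2 = x(2);                 % extract the entry first ...
    y  = x(1).^2/2 + x2.^2/2;  % ... then build y from it
    g  = dlgradient(y, x2);    % dy/dx2 = x2 = 2
end
```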
You can work around this by computing the full gradient and indexing it afterwards:
function out = oneDeriv(x,i)
    y = sum(x.^2)/2;
    out = dlgradient(y,x);   % gradient w.r.t. all of x
    if ~isempty(i)
        out = out(i);        % keep only the requested component(s)
    end
end
1 Comment
Matt J on 13 Feb 2024
Edited: Matt J on 13 Feb 2024
Thanks, Ben. So the answer is no?
