Efficient Computation of a Matrix Gradient
Shreyas Bharadwaj
9 Apr 2024
Commented: Shreyas Bharadwaj, 9 Apr 2024
Hi,
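I am trying to compute the gradient of a function $f(X)$ of an $N \times N$ matrix $X$. I have computed the element-wise gradient as $[\nabla f(X)]_{pq} = \sum_{i,j} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}$ and have verified that it is correct numerically (for my purposes of gradient descent).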
My MATLAB implementation of the above gradient is:
for p = 1:N
    for q = 1:N
        % entry (p,q): weighted sum over all entries of AXB
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes. I expect to need several hundred iterations for convergence.
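A rough operation count shows why the loop scales so badly. This is a back-of-the-envelope estimate, counting only the dominant multiplications and ignoring memory traffic and constant factors: each of the $N^2$ gradient entries touches all $N^2$ entries of w and AXB, whereas the same result can be obtained from two $N \times N$ matrix products.

```latex
\text{double loop: } \underbrace{N^2}_{(p,q)\ \text{pairs}} \times \underbrace{N^2}_{\text{terms per sum}}
  = N^4 = 750^4 \approx 3.2 \times 10^{11}
\text{two matrix products: } 2N^3 = 2 \cdot 750^3 \approx 8.4 \times 10^{8}
```

That is a factor of roughly $N/2 = 375$ in arithmetic alone. The loop also likely allocates several $N \times N$ temporary arrays on each of its $750^2 = 562{,}500$ iterations, which adds further overhead.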
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
2 Comments
Bruno Luong
9 Apr 2024
Edited: Bruno Luong, 9 Apr 2024
What is the typical size of w (or AXB)?
By the way, the first obvious optimization is to compute w .* AXB once, outside the loops.
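Here is a minimal sketch of that first optimization. It keeps the double loop from the question but computes C = w .* AXB once instead of $N^2$ times. The size N = 50 and the random complex data are placeholders, not values from the question. sum(sum(...)) and sum(T(:)) stand in for sum(..., 'all') so the snippet also runs on releases without the 'all' option.

```matlab
% Placeholder sizes and random complex data (assumed, not from the question)
N = 50;
w   = rand(N,N);
A   = rand(N,N) + 1i*rand(N,N);
B   = rand(N,N) + 1i*rand(N,N);
AXB = rand(N,N) + 1i*rand(N,N);

% Reference: the double loop from the question
gradRef = zeros(N,N);
for p = 1:N
    for q = 1:N
        gradRef(p,q) = sum(sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* AXB));
    end
end

% Same loop, but the weighted term w .* AXB is computed once outside the loops
C = w .* AXB;
gradPre = zeros(N,N);
for p = 1:N
    Ap = conj(A(:,p));                  % hoist the conjugated column out of the q loop
    for q = 1:N
        T = C .* (Ap * conj(B(q,:)));   % one element-wise multiply per entry instead of two
        gradPre(p,q) = sum(T(:));
    end
end

relErr = norm(gradPre(:) - gradRef(:), Inf) / norm(gradRef(:), Inf);
fprintf('relative difference: %g\n', relErr);
```

This only trims a constant factor. The loop still does $O(N^4)$ work overall, so on its own it will not make N = 750 practical.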
Accepted Answer
Bruno Luong
9 Apr 2024
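The best approach is to remove the loops altogether and compute the gradient with two matrix products (Method 3 below), timed here against the original loop for N = 200: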
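N = 200; % 750 in the actual problem
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
% Original double loop (reference)
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
% Method 3: apply the weights once, then two matrix products (' is the conjugate transpose)
tic
C = w .* AXB;
gradX = A' * C * B';
t2=toc
% Relative difference between the two results (should be at round-off level)
err = norm(gradX(:)-gradX_1(:),'inf') / norm(gradX_1(:))
fprintf('New code version 3 is %g times faster\n', t1/t2)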
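Why Method 3 gives the same result: write $G_{pq}$ for the entry gradX(p,q) computed by the loop, and $C_{ij} = w_{ij}\,(AXB)_{ij}$ for the element-wise product C = w .* AXB in the code. The double sum then factors into a product of three matrices. The derivation below uses $A^{H}$ for the conjugate transpose, which is MATLAB's ' operator.

```latex
G_{pq} = \sum_{i,j} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}
       = \sum_{i,j} \overline{A_{ip}}\; C_{ij}\; \overline{B_{qj}}
       = \sum_{i} (A^{H})_{pi} \sum_{j} C_{ij}\,(B^{H})_{jq}
       = \left(A^{H}\, C\, B^{H}\right)_{pq}
\quad\Longrightarrow\quad G = A^{H}\, C\, B^{H} \;=\; \texttt{A' * C * B'}
```

The element-wise product costs $O(N^2)$ and each matrix product $O(N^3)$, replacing the $O(N^4)$ loop.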
More Answers (1)
Bruno Luong
9 Apr 2024
I propose the following, with a timing test for N = 200:
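N = 200; % 750 in the actual problem
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
% Original double loop (reference)
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
% Version 1: precompute the weighted term once and reduce each entry to a dot product
gradX = zeros(N,N);
tic
C = w .* AXB;
C = reshape(C,1,[]);            % flatten to a 1-by-N^2 row vector
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:);       % outer product A(:,p)*B(q,:), N-by-N
        AB = reshape(AB,1,[]);
        gradX(p,q) = C * AB';   % ' conjugates, so this is sum(C .* conj(AB))
    end
end
t2=toc
err = norm(gradX(:)-gradX_1(:),'inf') / norm(gradX_1(:))
fprintf('New code version 1 is %g times faster\n', t1/t2)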
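Version 1 still forms one $N \times N$ outer product per (p,q) pair. One possible intermediate step between it and Method 3 vectorizes the inner q loop. This is my own variant, not from the thread. Row p of the gradient is A(:,p)' * C * B', because A(:,p)' is row p of the conjugate transpose of A. The size N = 60 and the random data below are placeholders.

```matlab
% Placeholder sizes and random complex data (assumed for illustration)
N = 60;
w   = rand(N,N);
A   = rand(N,N) + 1i*rand(N,N);
B   = rand(N,N) + 1i*rand(N,N);
AXB = rand(N,N) + 1i*rand(N,N);

C  = w .* AXB;      % weighted term, computed once
Bh = B';            % conjugate transpose of B, computed once

% Row-at-a-time: one vector-matrix-matrix product per p instead of N outer products
gradRow = zeros(N,N);
for p = 1:N
    gradRow(p,:) = (A(:,p)' * C) * Bh;   % A(:,p)' conjugates and transposes column p
end

% Fully vectorized form from the accepted answer, for comparison
gradFull = A' * C * B';

relErr = norm(gradRow(:) - gradFull(:), Inf) / norm(gradFull(:), Inf);
fprintf('relative difference: %g\n', relErr);
```

Each row costs about $2N^2$ multiply-adds, so this loop is $O(N^3)$ like Method 3. The single call A' * C * B' is still expected to be faster in practice, because it hands the whole product to the optimized matrix-multiply routine. This version mainly shows how the loops collapse step by step.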