How to improve the speed of calculating the trace in a script?

Xiaohan Du, 3 Apr 2018
Last comment: Xiaohan Du, 7 Apr 2018
Hi all,
In my project I have to compute the trace of many matrix products. The following script demonstrates the problem:
clear; clc;
% Total number of tests.
nTest = 500;
% Part 1: generate an nTest-by-2 cell array of random nd-by-nt matrices.
nd = 1000;
nt = 100;
dis = cell(nTest, 2);
dis = cellfun(@(v) rand(nd, nt), dis, 'UniformOutput', false);
% Part 2: compute a truncated SVD of each matrix,
% keeping only the first nRem singular vectors and values.
nRem = 50;
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(dis{isvd, jsvd}, 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end
% Part 3: for each pair of SVD results, compute a trace to fill disTrans.
% disTrans is not symmetric, so jtr must start from 1.
disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTrans(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
I ran the Profiler to find the slowest part: it is the trace computation in part 3, which is executed nTest^2 = 250,000 times. Unfortunately, the number of trace computations in my real project is also very large. Is there any way to speed up the trace calculation? The Profiler report was attached as a screenshot (image not available).
Many thanks!
Steven Lord, 3 Apr 2018
Note that built-in functions like svd, cellfun, rand, etc. don't show up in your Profiler report, but they do take time. If you open the entry for testuiTujSortImprove, I believe you'll see a line in the Children table reporting the time spent in built-in functions, and given how many times you called svd in particular, I think that will account for the "missing" time.
Xiaohan Du, 4 Apr 2018
Got it. I opened the entry for testuiTujSortImprove (screenshot not available), and it seems that the operations inside trace, i.e.
u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}'
take the majority of the time.
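A rough way to see why the product dominates is to count multiply-adds. The sketch below assumes the usual cost of m*k*n multiply-adds for an m-by-k times k-by-n product and uses the sizes from the script (nd = 1000, nt = 100, nRem = 50). It also counts the regrouping (V2'*V1)*S1*(U1'*U2)*S2 from the accepted answer. This is an arithmetic estimate only; real timings also depend on the BLAS library and memory traffic.

```matlab
% Estimated multiply-adds per (itr, jtr) pair; an (m x k)*(k x n) product
% is assumed to cost m*k*n multiply-adds.
nd = 1000; nt = 100; nRem = 50;

% V1*S1'*U1'*U2*S2*V2', evaluated left to right as MATLAB does:
leftToRight = 0;
leftToRight = leftToRight + nt*nRem*nRem;  % V1*S1'   -> nt-by-nRem
leftToRight = leftToRight + nt*nRem*nd;    % (.)*U1'  -> nt-by-nd
leftToRight = leftToRight + nt*nd*nRem;    % (.)*U2   -> nt-by-nRem
leftToRight = leftToRight + nt*nRem*nRem;  % (.)*S2   -> nt-by-nRem
leftToRight = leftToRight + nt*nRem*nt;    % (.)*V2'  -> nt-by-nt

% (V2'*V1)*S1*(U1'*U2)*S2, with the long dimensions contracted first:
regrouped = 0;
regrouped = regrouped + nRem*nt*nRem;      % V2'*V1   -> nRem-by-nRem
regrouped = regrouped + nRem*nd*nRem;      % U1'*U2   -> nRem-by-nRem
regrouped = regrouped + 3*nRem^3;          % three nRem-by-nRem products

fprintf('left to right: %d\nregrouped:     %d\n', leftToRight, regrouped);
```

With these sizes the left-to-right chain needs 11,000,000 multiply-adds per pair, almost all of it in the two steps involving nd, while the regrouped form needs 3,125,000.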


Accepted answer

Christine Tobler, 3 Apr 2018
You can make the trace computation faster as follows. Currently, the inputs are two truncated SVDs, A1 = U1 * S1 * V1' and A2 = U2 * S2 * V2', and you are computing
trace(A1'*A2)
correct? In your code u1 = {U1, S1, V1} and u2 = {U2, S2, V2}, so u1{3}*u1{2}'*u1{1}'*u2{1}*u2{2}*u2{3}' is exactly V1*S1'*U1'*U2*S2*V2' = A1'*A2. After substituting A1 and A2, you can use the cyclic property of the trace, trace(A*B) == trace(B*A) (note that in general trace(A*B*C) ~= trace(A*C*B); see Wikipedia).
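A quick numerical check of the cyclic property, using small integer matrices chosen only for illustration (so the arithmetic is exact): cyclic shifts of A*B*C keep the trace, swapping two factors does not, and for rectangular factors the two products may even have different sizes while their traces still agree.

```matlab
% Small integer matrices, chosen only for illustration.
A = [1 2; 3 4];
B = [0 1; 1 0];
C = [1 0; 0 2];

% Cyclic shifts of the factors leave the trace unchanged ...
tABC = trace(A*B*C);   % 8
tBCA = trace(B*C*A);   % 8
tCAB = trace(C*A*B);   % 8
% ... but an arbitrary reordering does not.
tACB = trace(A*C*B);   % 7

% Rectangular factors: A2*B2 is 3-by-3 while B2*A2 is 5-by-5,
% yet both traces agree.
A2 = reshape(1:15, 3, 5);
B2 = reshape(1:15, 5, 3);
fprintf('trace(A2*B2) = %d, trace(B2*A2) = %d\n', trace(A2*B2), trace(B2*A2));
```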
Moving V2' to the front, and using S1' == S1 because S1 is diagonal, you can rearrange
trace(V1*S1'*U1'*U2*S2*V2') == trace((V2'*V1) * S1 * (U1'*U2) * S2)
Make sure that the parentheses are set like this, so that the products involving the long dimensions (V2'*V1 and U1'*U2) are formed first and all remaining operations are on nRem-by-nRem matrices.
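A minimal sketch of part 3 with the regrouped expression, checked against the original left-to-right expression. The sizes below are hypothetical and deliberately small so the check runs quickly; the loop structure and the {U, S, V} cell layout are the same as in the question.

```matlab
% Hypothetical small sizes, for illustration only.
nTest = 20; nd = 200; nt = 30; nRem = 10;

% Truncated SVDs stored as {U, S, V}, as in part 2 of the question.
disSVD = cell(nTest, 2);
for k = 1:numel(disSVD)
    [u, s, v] = svd(rand(nd, nt), 0);
    disSVD{k} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
end

disTransOld = zeros(nTest);
disTransNew = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        % Original: left to right, forms nt-by-nd and nt-by-nRem intermediates.
        disTransOld(itr, jtr) = trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
        % Regrouped: V2'*V1 and U1'*U2 first, then only nRem-by-nRem products.
        % u1{2} needs no transpose because S1 is diagonal.
        disTransNew(itr, jtr) = trace((u2{3}' * u1{3}) * u1{2} * (u1{1}' * u2{1}) * u2{2});
    end
end

relErr = max(abs(disTransNew(:) - disTransOld(:))) / max(abs(disTransOld(:)));
fprintf('max relative difference: %g\n', relErr);
```

The two results should agree up to rounding error; only the order of the floating-point operations changes.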
By the way, for real matrices you can also rewrite trace(A'*B) as sum(sum(A.*B)), but I'm not sure whether this will give you a speedup in this case.
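Why this identity holds: the k-th diagonal entry of A'*B is the dot product of column k of A with column k of B, so summing the elementwise product A.*B gives the trace without forming the full product A'*B. A small check with integer matrices chosen only for illustration (for complex data the elementwise form would need conj(A)):

```matlab
% Integer matrices chosen only for illustration.
A = reshape(1:12, 4, 3);      % 4-by-3
B = reshape(12:-1:1, 4, 3);   % 4-by-3

t1 = trace(A' * B);           % forms the 3-by-3 product, keeps its diagonal
t2 = sum(sum(A .* B));        % column-wise dot products, no full product
fprintf('trace(A''*B) = %d, sum(sum(A.*B)) = %d\n', t1, t2);
```

Both expressions evaluate to 364 here.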
Christine Tobler, 4 Apr 2018
Trace is a very quick operation compared with the matrix multiplications, so most of the speed-up probably comes from those multiplications. The trace call itself should also become a bit faster, because it now acts on nRem-by-nRem matrices instead of nt-by-nt matrices.
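You can try splitting the expression into separate statements:
M3 = u2{3}' * u1{3};           % V2'*V1, nRem-by-nRem
M1 = u1{1}' * u2{1};           % U1'*U2, nRem-by-nRem
M = M3 * u1{2} * M1 * u2{2};   % only nRem-by-nRem products remain
trace(M);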
This way, the profiler will tell you how much time is spent in each line, and you can see how much time the trace call takes.
Xiaohan Du, 7 Apr 2018
Thanks Chris!
