Fast implementation of max-plus matrix multiplication

I am trying to implement a fast max-plus algebra multiplication of square matrices in MATLAB. The max-plus product of two matrices $A$ and $B$ is the matrix $C$ with entries $c_{ij} = \max_{k} (a_{ik} + b_{kj})$.
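Here is a small worked example of the definition. The matrices are chosen purely for illustration, and $A \otimes B$ is assumed notation for the max-plus product:

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix},\quad
B = \begin{pmatrix} 0 & 1 \\ 2 & 3 \end{pmatrix},\quad
A \otimes B =
\begin{pmatrix}
\max(1+0,\ 2+2) & \max(1+1,\ 2+3) \\
\max(3+0,\ 4+2) & \max(3+1,\ 4+3)
\end{pmatrix}
= \begin{pmatrix} 4 & 5 \\ 6 & 7 \end{pmatrix}
```

This is the ordinary matrix product with "multiply" replaced by "+" and "sum" replaced by "max".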
A naive implementation of the max-plus multiplication in MATLAB is:
function C = mp_prod(A,B)
n = size(A,1);
C = -inf*ones(n);
for i = 1:n
    for j = 1:n
        for k = 1:n
            C(i,j) = max(C(i,j), A(i,k) + B(k,j));
        end
    end
end
However, this is not particularly fast. The fastest implementation I could come up with is:
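function C = fast_mp_prod(A,B)
n = size(A,1);
C = -inf*ones(n);
A = transpose(A);
for i = 1:n
    % A(:,i) is row i of the original A as a column vector. Implicit
    % expansion (R2016b or later) adds it to every column of B, and max()
    % reduces down each column, which gives row i of the max-plus product.
    C(i,:) = max(A(:,i) + B);
end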
This takes about 0.18 seconds to multiply two 100×100 matrices 100 times on my computer. Since the same test with the "standard" multiplication takes about 0.004 seconds (45 times faster), I was wondering whether there is a way to speed up the code and obtain more comparable timings in MATLAB.
If this is not possible, would it be worth trying to create a package similar to LAPACK for max-plus algebra to get better performance? Or is this the maximum achievable speed for reasons like "the CPU is inherently slower at computing the maximum than the product of two numbers"?

4 comments

Jan
Jan on 13 Apr 2021
Edited: Jan on 13 Apr 2021
Of course, finding the maximum of a vector with 100 elements is slower than multiplying two numbers: you have to access 100 elements in memory and perform 99 comparisons.
-inf(n) might be faster than -inf*ones(n), but this is not the bottleneck of the code. In the fast version, C = zeros(n) would be sufficient, because every row of C is overwritten inside the loop.
Davide Zorzenon
Davide Zorzenon on 13 Apr 2021
Edited: Davide Zorzenon on 13 Apr 2021
Maybe I should have asked: is computing $\max_k (a_{ik} + b_{kj})$ inherently slower than computing $\sum_k a_{ik} b_{kj}$? Here the number of operations is the same (both a naive implementation of the max-plus multiplication and a naive implementation of the standard multiplication of two $n \times n$ matrices require $\mathcal{O}(n^3)$ operations).
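Counting the operations for two $n \times n$ matrices makes the comparison concrete. The rates below come only from the timings quoted in the question (100 products of 100×100 matrices), so they are rough:

```latex
\begin{aligned}
&\text{standard: } c_{ij} = \sum_{1 \le k \le n} a_{ik} b_{kj}
  &&\Rightarrow\ n^3 \text{ multiplications} + n^2(n-1) \text{ additions} \\
&\text{max-plus: } c_{ij} = \max_{1 \le k \le n} (a_{ik} + b_{kj})
  &&\Rightarrow\ n^3 \text{ additions} + n^2(n-1) \text{ comparisons} \\
&n = 100: && 10^6 + 990\,000 \approx 2 \times 10^6 \text{ operations per product} \\
&100 \text{ products}: && \frac{2 \times 10^8}{0.004\ \text{s}} \approx 5 \times 10^{10}\ \text{op/s},
  \qquad \frac{2 \times 10^8}{0.18\ \text{s}} \approx 1.1 \times 10^{9}\ \text{op/s}
\end{aligned}
```

The operation counts match, so the roughly 45× gap reflects how fast each operation is executed. Optimized BLAS code (SIMD, cache blocking, multithreading) sustains a much higher rate than the MATLAB-level max() loop.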
Jan
Jan on 13 Apr 2021
The naive max() implementation contains a branch. In general, this slows down processing because it interferes with pipelining in the CPU, which has to predict heuristically which branch will be taken. In the case of max(), about half of the predictions are wrong.
The matrix multiplication is performed by a highly optimized library (BLAS). Even the naive loop implementation in Matlab benefits from Matlab's JIT acceleration, which does not handle max() with the same efficiency.
Do you have a C compiler installed?
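A branch-free formulation is easy to try in C. How much branch misprediction costs depends on the data. For independent random entries, the update s > t in a running maximum over p terms is taken only about H_p ≈ ln p + 0.58 times on average (about 5 times for p = 100), so the predictor is usually right. The sketch below is plain C without the MEX API, and mp_dot is a hypothetical helper, not part of the thread's code. It writes the update as a conditional expression, which compilers can, but are not required to, turn into a max instruction without a jump.

```c
#include <stddef.h>
#include <math.h>

/* Hypothetical helper: max-plus inner product
 *   t = max over k of (a[k * a_stride] + b[k]),  k = 0 .. len-1.
 * With a = A + i and a_stride = m for a column-major m-by-p matrix A,
 * this combines row i of A with a contiguous column b of B. */
double mp_dot(const double *a, size_t a_stride, const double *b, size_t len)
{
    double t = -INFINITY;            /* neutral element of max */
    for (size_t k = 0; k < len; k++) {
        double s = a[k * a_stride] + b[k];
        /* Same meaning as "if (s > t) t = s;", written as a select,
           which is easier for compilers to emit as a branch-free max
           instruction (not guaranteed). */
        t = (s > t) ? s : t;
    }
    return t;
}
```

With this helper, the inner k-loop of a MEX kernel like the one in the answer below could be written as `t = mp_dot(A + i, m, Bj, p);`. Whether this is faster should be measured.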
Thank you for the explanation! Yes, I have one.


Accepted Answer

Jan
Jan on 13 Apr 2021
Edited: Jan on 15 Apr 2021
A C-mex version:
[EDITED: Check of type double is added]
// mp_prod_mex.c
// C = mp_prod_mex(A, B)
// INPUT: A, B: Real double [m x p] and [p x n] matrices.
// OUTPUT: Real double [m x n] matrix.
//
// Equivalent Matlab code:
// [m, p] = size(A); n = size(B, 2);
// C = -inf(m, n);
// for i = 1:m
//   for j = 1:n
//     for k = 1:p
//       C(i,j) = max(C(i,j), A(i,k) + B(k,j));
//     end
//   end
// end
//
// COMPILE:
// mex -O -R2018a mp_prod_mex.c
//
// Jan, Heidelberg, (C) 2021, License: CC BY-SA 3.0
// Handling of rectangular matrices: Bruno Luong
#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *A, *Ak, *B, *Bj, *Bk, *C, mInf = -mxGetInf();
  register double s, t;
  mwSize m, n, p, i, j, k;

  // Check the inputs before accessing their data
  // (mxGetDoubles returns NULL for non-double arrays):
  if (nrhs != 2 ||
      !mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ||
      mxIsComplex(prhs[0]) || mxIsComplex(prhs[1])) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Inputs must be 2 real double matrices.");
  }

  // Get inputs:
  A = mxGetDoubles(prhs[0]);
  B = mxGetDoubles(prhs[1]);
  m = mxGetM(prhs[0]);  // size(A, 1)
  n = mxGetN(prhs[1]);  // size(B, 2)
  p = mxGetN(prhs[0]);  // size(A, 2), must equal size(B, 1)
  if (mxGetM(prhs[1]) != p) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Size of A and B do not match.");
  }

  // Create output:
  plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
  C = mxGetDoubles(plhs[0]);

  // Calculate result:
  for (j = 0; j < n; j++) {   // Loop over B at 1st for linear access to C
    Bj = B + j * p;           // Column j of B (B has p rows)
    for (i = 0; i < m; i++) {
      Ak = A + i;             // A(i, 0)
      Bk = Bj;
      t = mInf;               // Neutral element of max
      for (k = 0; k < p; k++) {
        s = *Ak + *Bk++;      // A(i,k) + B(k,j)
        Ak += m;              // Same row, next column of A
        if (s > t) {
          t = s;
        }
      }
      *C++ = t;
    }
  }
}
Timings, 100x100 input, 100 iterations:
% Matlab R2018b, i5 mobile:
Elapsed time is 0.365667 seconds. % 1st version from question
Elapsed time is 0.253912 seconds. % 2nd version from question
Elapsed time is 0.337319 seconds. % Bruno's version
Elapsed time is 0.077329 seconds. % C mex
Transposing A to get contiguous memory access gave no significant advantage.
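Reordering the loops is an alternative to transposing A. The sketch below is plain C without the MEX API. mp_prod_jki is a hypothetical name, and C99 `restrict` is assumed to be available. It computes each column of C by sweeping over k and updating the whole column at once, so the innermost loop reads column k of A and writes column j of C contiguously. It has not been benchmarked here. Whether it beats the MEX loop above depends on the compiler and on whether the compiler vectorizes the inner loop.

```c
#include <stddef.h>
#include <math.h>

/* Hypothetical name. Column-major max-plus product C = A (x) B with
 * A: m-by-p, B: p-by-n, C: m-by-n (same memory layout as MATLAB arrays).
 * Loop order j-k-i: the innermost loop walks down column k of A and
 * column j of C, both contiguous, so A does not need to be transposed.
 * restrict tells the compiler that C does not overlap A or B, which
 * helps it vectorize the inner loop. */
void mp_prod_jki(const double *restrict A, const double *restrict B,
                 double *restrict C, size_t m, size_t p, size_t n)
{
    for (size_t j = 0; j < n; j++) {
        double *Cj = C + j * m;               /* column j of C */
        for (size_t i = 0; i < m; i++)
            Cj[i] = -INFINITY;                /* neutral element of max */
        for (size_t k = 0; k < p; k++) {
            const double *Ak = A + k * m;     /* column k of A */
            const double bkj = B[k + j * p];  /* scalar B(k,j) */
            for (size_t i = 0; i < m; i++) {
                double s = Ak[i] + bkj;
                Cj[i] = (s > Cj[i]) ? s : Cj[i];
            }
        }
    }
}
```

Because every C(i,j) starts at -INFINITY, an empty inner dimension (p == 0) yields a matrix of -INFINITY values.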

12 comments

Really impressive, thank you! Here are the timings on my computer:
% Matlab R2019a, i7:
Elapsed time is 0.244774 seconds. % 1st version from question
Elapsed time is 0.172351 seconds. % 2nd version from question
Elapsed time is 0.242824 seconds. % Bruno's version
Elapsed time is 0.090291 seconds. % C mex
Bruno Luong
Bruno Luong on 14 Apr 2021
Edited: Bruno Luong on 14 Apr 2021
Warning: Jan's code won't work on non-square matrices, and in particular it would break if size(A,2) == 0. I would think the correct max/+ algebra would return -inf(m,n) in that case.
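Here is a minimal reference sketch of the semantics described in this comment. It is plain C without the MEX API, and the function name mp_prod_ref is hypothetical. Matrices are column-major as in MATLAB. Each entry starts at -INFINITY, so an empty inner dimension (p == 0) gives a result filled with -INFINITY, the analogue of -inf(m,n). Starting from -INFINITY instead of from the first product term also avoids reading A(i,1) and B(1,j) when they do not exist.

```c
#include <stddef.h>
#include <math.h>

/* Hypothetical reference implementation, column-major storage:
 * A is m-by-p, B is p-by-n, C is m-by-n.
 * C(i,j) = max over k of A(i,k) + B(k,j), or -INFINITY if p == 0. */
void mp_prod_ref(const double *A, const double *B, double *C,
                 size_t m, size_t p, size_t n)
{
    for (size_t j = 0; j < n; j++) {
        for (size_t i = 0; i < m; i++) {
            double t = -INFINITY;             /* neutral element of max */
            for (size_t k = 0; k < p; k++) {  /* skipped when p == 0 */
                double s = A[i + k * m] + B[k + j * p];
                if (s > t)
                    t = s;
            }
            C[i + j * m] = t;
        }
    }
}
```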
Jan
Jan on 14 Apr 2021
Edited: Jan on 14 Apr 2021
@Bruno Luong: Thanks for the hint. But Davide explicitly asked about square matrices. For non-square inputs, my C code throws an error message. If the inputs are empty, the output is empty too, as expected. In addition, this is exactly what your suggested -inf(m,n) returns for m==n and n==0:
mp_prod_mex([], []) % []
-inf(0, 0) % []
Therefore I do not understand why you call this a "warning".
I've added a check that the inputs are of type double.
@Bruno Luong: in fact, I was particularly interested in the case with square matrices. Moreover, Jan's code seems to be easily generalizable to non-square matrices.
Bruno Luong
Bruno Luong on 14 Apr 2021
Edited: Bruno Luong on 14 Apr 2021
Yes, it's a warning for readers who are interested in generic max/+ matrices.
Jan's code can be modified: the special treatment of the first element should be removed, and the k-loop should start from 0 instead. Then the C matrix should be initialized with -Inf, not 0, since -Inf is the neutral element of the max operator (⊕).
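A one-entry example shows why the initial value matters. The values are chosen here for illustration. Only $-\infty$ leaves every real number unchanged under max, so starting from 0 silently clips negative results:

```latex
\begin{aligned}
&\max(-\infty, x) = x \quad \text{for all } x \in \mathbb{R},
  \qquad \max(0, x) = 0 \quad \text{for } x < 0 \\
&A = (-5),\ B = (-3):\quad
  \max(-\infty,\ -5 + (-3)) = -8 \ \text{(correct)},
  \qquad \max(0,\ -5 + (-3)) = 0 \ \text{(wrong)}
\end{aligned}
```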
Jan
Jan on 14 Apr 2021
Edited: Jan on 15 Apr 2021
Are you sure that -Inf should be returned even for empty input matrices? Neither -inf(0,n) nor -inf(m,0) contains any -Inf value; both contain no values at all. So what is the size of the output for empty inputs? Is this the intended Matlab code:
[m, p] = size(A);
n = size(B, 2);
C = -inf(m, n);
for i = 1:m
    for j = 1:n
        for k = 1:p
            C(i,j) = max(C(i,j), A(i,k) + B(k,j));
        end
    end
end
Then
mp_prod(zeros(2,0), zeros(0,3))
% -Inf -Inf -Inf
% -Inf -Inf -Inf
If so, let me mention that your code throws an error for empty inputs and does not return -Inf values either:
bl_mp_prod(zeros(2,0), zeros(0,3))
% Error using reshape
% To RESHAPE the number of elements must not change.
% Error in kgd>bl_mp_prod (line 82)
% C=reshape(C,[m n])
The C-mex code that accepts non-square inputs and returns -Inf values:
[EDITED] This version, which handles rectangular inputs, has been moved to the answer.
Bruno Luong
Bruno Luong on 14 Apr 2021
Edited: Bruno Luong on 14 Apr 2021
Yes, my code doesn't handle this case well either. If
A = zeros(3,0);
B = zeros(0,2);
then mp_prod(A,B) should return -inf(3,2) to be theoretically consistent with (non-square) max/+ algebra. This is similar to A*B, but with the "0" element replaced by -Inf, the neutral element of max.
I took the liberty of modifying Jan's code to extend it to non-square matrices:
// mp_gprod_mex.c
// C = mp_gprod_mex(A, B)
// INPUT: A: Real double [m x p] matrix.
//        B: Real double [p x n] matrix.
// OUTPUT: C: Real double [m x n] matrix.
//
// COMPILE:
// mex -O -R2018a mp_gprod_mex.c
//
// Jan, Heidelberg, (C) 2021, License: CC BY-SA 3.0
// Modified by Bruno Luong
#include "mex.h"
#include <math.h>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *A, *Ak, *B, *Bj, *Bk, *C;
  register double s, t;
  mwSize m, n, p, i, j, k;

  // Check validity of inputs before accessing their data:
  if (nrhs != 2 ||
      !mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ||
      mxIsComplex(prhs[0]) || mxIsComplex(prhs[1]) ||
      mxGetM(prhs[1]) != mxGetN(prhs[0])) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
        "mp_gprod_mex: Inputs must be 2 real double matrices with size(A,2) == size(B,1).");
  }

  // Get inputs:
  A = mxGetDoubles(prhs[0]);
  B = mxGetDoubles(prhs[1]);
  m = mxGetM(prhs[0]);  // size(A, 1)
  p = mxGetN(prhs[0]);  // size(A, 2) == size(B, 1)
  n = mxGetN(prhs[1]);  // size(B, 2)

  // Create output:
  plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
  C = mxGetDoubles(plhs[0]);

  // Calculate result:
  for (j = 0; j < n; j++) {
    Bj = B + j * p;          // Column j of B
    for (i = 0; i < m; i++) {
      Ak = A + i;            // A(i, 0)
      Bk = Bj;
      t = -INFINITY;         // Neutral element of max; stays -Inf if p == 0
      for (k = 0; k < p; k++) {
        s = *Ak + *Bk++;     // A(i,k) + B(k,j)
        Ak += m;             // Same row, next column of A
        if (s > t) {
          t = s;
        }
      }
      *C++ = t;
    }
  }
}
Bruno Luong
Bruno Luong on 14 Apr 2021
Edited: Bruno Luong on 14 Apr 2021
Jan, are you sure about the increment step:
Ak += p;
Jan
Jan on 14 Apr 2021
Thanks Bruno. You are right: "Ak += m;" is correct. I had reused the test cases with square matrices only.
Davide Zorzenon
Davide Zorzenon on 14 Apr 2021
Edited: Davide Zorzenon on 14 Apr 2021
@Bruno Luong: thanks for pointing out the case of matrices with dimensions $m \times 0$ and $0 \times n$. I had never thought about that, but to be consistent the result of the max-plus multiplication should be -inf(m,n), as you wrote.
@Jan I would love to connect with you and learn more about how this implementation works. Are you slicing the matrix up? I am confused by the (A,1) and (A,2).
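The MEX code does not slice or copy the matrices. It walks pointers through MATLAB's column-major storage. The "(A,1)" and "(A,2)" presumably refer to the code comments size(A,1) and size(A,2), that is, the row count m and the column count p of A. Here is a small plain-C sketch of the indexing the loop relies on; the function names are hypothetical:

```c
#include <stddef.h>

/* MATLAB stores matrices column by column (column-major). With 0-based
 * indices, element A(i,k) of an m-by-p matrix is at offset i + k*m:
 * neighbours in the same column are 1 apart, and neighbours in the
 * same row are m = size(A,1) apart. */
size_t cm_offset(size_t i, size_t k, size_t m)
{
    return i + k * m;
}

/* Copy row i of the m-by-p column-major matrix A into row_out[0..p-1]. */
void copy_row(const double *A, size_t m, size_t p, size_t i, double *row_out)
{
    for (size_t k = 0; k < p; k++) {
        /* Offsets i, i+m, i+2m, ...: the same sequence the MEX loop
           produces by starting at A + i and doing "Ak += m". */
        row_out[k] = A[cm_offset(i, k, m)];
    }
}
```

For example, the 2×3 matrix [1 2 3; 4 5 6] is stored as {1, 4, 2, 5, 3, 6}, and row 1 (the second row) is read from offsets 1, 3 and 5.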

Connectez-vous pour commenter.

More Answers (1)

Bruno Luong
Bruno Luong on 13 Apr 2021
Edited: Bruno Luong on 13 Apr 2021
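function C = mp_prod(A,B)
m = size(A,1);
n = size(B,2);
AA = reshape(A,m,1,[]);     % m x 1 x p, AA(i,1,k) = A(i,k)
BB = reshape(B.',1,n,[]);   % 1 x n x p, BB(1,j,k) = B(k,j)
C = max(AA+BB,[],3);        % m x n x p sums, reduced along dim 3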
tic/toc result
>> A=rand(100);
>> B=rand(100);
>> tic; C = mp_prod(A,B); toc
Elapsed time is 0.005206 seconds.
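One plausible reason this vectorized form is not faster in the 100-iteration test is that AA+BB materializes the full $m \times n \times p$ array of sums before max() reduces it. For the 100×100 case used in the thread, that temporary is

```latex
m \cdot n \cdot p \cdot 8\ \text{bytes}
  = 100 \cdot 100 \cdot 100 \cdot 8\ \text{bytes}
  = 8 \times 10^{6}\ \text{bytes} \approx 8\ \text{MB}
```

per call. It must be written and then read back, whereas fast_mp_prod() only creates a $100 \times 100$ temporary (80 000 bytes) per row.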

8 comments

Thank you Bruno. However, your implementation does not seem to be faster than fast_mp_prod(). The timings I reported in the question referred to this test:
tic; for i = 1:100; mp_prod(A,B); end; toc
(that is, 100 multiplications of two 100×100 matrices). Your code takes about 0.25 seconds to perform this test on my computer, which is slower than fast_mp_prod().
Bruno Luong
Bruno Luong on 13 Apr 2021
Edited: Bruno Luong on 13 Apr 2021
On my computer, the code below (working in column-major order) is slightly faster (about 10%):
>> A=rand(100);
>> B=rand(100);
>> tic; for i = 1:1000; fast_mp_prod(A,B); end; toc
Elapsed time is 3.127958 seconds.
>> tic; for i = 1:1000; bl_mp_prod(A,B); end; toc
Elapsed time is 2.723747 seconds.
>>
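Column-major code:
function C = bl_mp_prod(A,B)
m = size(A,1);
n = size(B,2);
A = reshape(A.',[],m,1);    % p x m, column i holds row i of A
B = reshape(B,[],1,n);      % p x 1 x n, page j holds column j of B
C = A + B;                  % p x m x n sums
C = max(C,[],1);            % 1 x m x n, reduced along the contiguous dim 1
C = reshape(C,[m n]);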
I see; I guess the actual speed depends a lot on the exact architecture. This is what I get:
>> A = rand(100); B = rand(100);
>> tic; for i = 1:1000; fast_mp_prod(A,B); end; toc
Elapsed time is 1.630039 seconds.
>> tic; for i = 1:1000; bl_mp_prod(A,B); end; toc
Elapsed time is 2.110560 seconds.
Jan
Jan on 13 Apr 2021
Edited: Jan on 13 Apr 2021
Maybe you and Bruno use different Matlab versions?
Bruno Luong
Bruno Luong on 14 Apr 2021
Edited: Bruno Luong on 14 Apr 2021
Mine is R2021a, on pretty old hardware: a 7-year-old PC running Windows 8.
You guessed right, Jan: I am running the code on Matlab R2019a. The hardware probably plays a big role here, too: my computer is a 2-year-old Windows 10 machine with an i7 processor (2.20 GHz) and 16 GB RAM.
Jan
Jan on 14 Apr 2021
Edited: Jan on 14 Apr 2021
@Davide Zorzenon: Which machine and Matlab version are you using?
The different timings could mean that Davide's setup is more efficient for the loop, or that Bruno's setup is more efficient for the vectorized solution.
@Jan: I use Matlab R2019a and my computer is a Windows 10 x64 intel core i7 (2.20GHz), with 16GB RAM and 6 cores.

