Fast implementation of max-plus matrix multiplication
I am trying to implement a fast max-plus algebra multiplication of square matrices in MATLAB. The max-plus multiplication ⊗ between two $n \times n$ matrices $A$ and $B$ returns the matrix $C = A \otimes B$ such that

$$C_{ij} = \max_{k=1,\dots,n} \left( A_{ik} + B_{kj} \right).$$

A naive implementation of the max-plus multiplication in MATLAB is:
function C = mp_prod(A,B)
    n = size(A,1);
    C = -inf*ones(n);   % -Inf is the neutral element of max
    for i = 1:n
        for j = 1:n
            for k = 1:n
                C(i,j) = max(C(i,j), A(i,k) + B(k,j));
            end
        end
    end
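To make the definition concrete, here is a small C sketch (not from the thread; the name mp_prod_2x2 is hypothetical) that applies the same triple loop to two 2×2 matrices. It also shows why the result is initialized to -Inf: -Inf plays the role for max that 0 plays for an ordinary sum, and the max-plus identity matrix therefore has 0 on the diagonal and -Inf elsewhere.

```c
#include <math.h>

/* Max-plus product C = A (x) B of two 2 x 2 matrices.
   C arrays are row-major, unlike MATLAB's column-major layout. */
void mp_prod_2x2(const double A[2][2], const double B[2][2], double C[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            double t = -INFINITY;   /* neutral element of max */
            for (int k = 0; k < 2; k++) {
                double s = A[i][k] + B[k][j];   /* A(i,k) + B(k,j) */
                if (s > t)
                    t = s;
            }
            C[i][j] = t;
        }
    }
}
```

For A = [1 2; 3 0] and B = [0 4; 1 -1] this gives C = [3 5; 3 7] (for example, C(1,1) = max(1+0, 2+1) = 3), and multiplying A by the identity [0 -Inf; -Inf 0] returns A unchanged.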
However, this is not particularly fast. The fastest implementation I could come up with is:
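function C = fast_mp_prod(A,B)
    n = size(A,1);
    C = -inf*ones(n);
    A = transpose(A);   % column i of A.' is row i of the original A
    for i = 1:n
        % Implicit expansion (R2016b+): element (k,j) of A(:,i) + B is
        % A_orig(i,k) + B(k,j); max over dimension 1 reduces over k.
        C(i,:) = max(A(:,i) + B);
    end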
This takes about 0.18 seconds to multiply two 100×100 matrices 100 times on my computer. Since the same test with "standard" matrix multiplication takes about 0.004 seconds (45 times faster), I was wondering whether there is a way to speed up the code and obtain more comparable timings in MATLAB.
If this is not possible, would it be worth trying to create a package similar to LAPACK for max-plus algebra to get better performance? Or is this the maximum achievable speed for reasons such as "the CPU is inherently slower at computing a maximum than the product of two numbers"?
4 comments
Jan
on 13 Apr 2021
A naive max() contains a branch. Branches slow down processing in general because they disrupt the CPU's pipelining: the CPU guesses which branch will be taken using heuristics, and for max() about half of these predictions are wrong.
Standard matrix multiplication is performed by a highly optimized library (BLAS). Even the naive loop in MATLAB benefits from MATLAB's JIT acceleration, but the JIT does not handle max() with the same efficiency.
Do you have a C compiler installed?
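The branch argument can be illustrated with a C sketch of the reduction for a single entry, C(i,j) = max over k of a[k] + b[k] (the names mp_max and mp_dot are hypothetical, not from the thread). It assumes row i of A and column j of B are both contiguous in memory, for example after transposing A. The conditional expression in mp_max is a pattern compilers usually turn into a select or max instruction (such as maxsd on x86) instead of a conditional jump, though whether they do depends on the compiler and optimization flags. Four independent partial maxima shorten the dependency chain on a single accumulator; because max is associative and commutative, combining them at the end gives the same result.

```c
#include <math.h>
#include <stddef.h>

/* Returns s if s > t, else t. A NaN in s is ignored, like the
   "if (s > t) t = s;" loop and MATLAB's default max(). */
static inline double mp_max(double s, double t)
{
    return s > t ? s : t;
}

/* Max-plus "dot product": max over k of a[k] + b[k], -Inf if n == 0. */
double mp_dot(const double *a, const double *b, size_t n)
{
    double t0 = -INFINITY, t1 = -INFINITY, t2 = -INFINITY, t3 = -INFINITY;
    size_t k = 0;
    /* Main loop: four independent running maxima. */
    for (; k + 4 <= n; k += 4) {
        t0 = mp_max(a[k]     + b[k],     t0);
        t1 = mp_max(a[k + 1] + b[k + 1], t1);
        t2 = mp_max(a[k + 2] + b[k + 2], t2);
        t3 = mp_max(a[k + 3] + b[k + 3], t3);
    }
    /* Remaining 0 to 3 elements. */
    for (; k < n; k++) {
        t0 = mp_max(a[k] + b[k], t0);
    }
    return mp_max(mp_max(t0, t1), mp_max(t2, t3));
}
```

For a = [1 2 3 4 5] and b = [5 1 0 2 -1] the sums are 6, 3, 3, 6, 4, so mp_dot returns 6.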
Accepted Answer
Jan
on 13 Apr 2021
Edited: Jan
on 15 Apr 2021
A C-mex version:
[EDITED: Added a check that the inputs are of type double]
// mp_prod_mex.c
// C = mp_prod_mex(A, B)
// INPUT:  A, B: Real double [m x p] and [p x n] matrices.
// OUTPUT: Real double [m x n] matrix.
//
// Equivalent Matlab code:
//   [m, p] = size(A); n = size(B, 2);
//   C = -inf(m, n);
//   for i = 1:m
//     for j = 1:n
//       for k = 1:p
//         C(i,j) = max(C(i,j), A(i,k) + B(k,j));
//       end
//     end
//   end
//
// COMPILE:
//   mex -O -R2018a mp_prod_mex.c
//
// Jan, Heidelberg, (C) 2021, License: CC BY-SA 3.0
// Handling of rectangular matrices: Bruno Luong

#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *A, *Ak, *B, *Bj, *Bk, *C, mInf = -mxGetInf();
  register double s, t;
  mwSize m, n, p, i, j, k;

  // Check inputs before reading their data:
  if (nrhs != 2) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadNInput",
                      "mp_prod_mex: 2 inputs required.");
  }
  if (!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ||
      mxIsComplex(prhs[0]) || mxIsComplex(prhs[1]) ||
      mxIsSparse(prhs[0]) || mxIsSparse(prhs[1])) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Inputs must be full real double matrices.");
  }

  // Get inputs:
  A = mxGetDoubles(prhs[0]);
  B = mxGetDoubles(prhs[1]);
  m = mxGetM(prhs[0]);  // size(A, 1)
  n = mxGetN(prhs[1]);  // size(B, 2)
  p = mxGetN(prhs[0]);  // size(A, 2), must equal size(B, 1)
  if (mxGetM(prhs[1]) != p) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Sizes of A and B do not match.");
  }

  // Create output:
  plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
  C = mxGetDoubles(plhs[0]);

  // Calculate result:
  for (j = 0; j < n; j++) {   // Loop over columns of B first for linear access to C
    Bj = B + j * p;           // Column j of B; B has p rows
    for (i = 0; i < m; i++) {
      Ak = A + i;             // A(i,0); stepping by m walks along row i
      Bk = Bj;
      t = mInf;
      for (k = 0; k < p; k++) {
        s = *Ak + *Bk++;      // A(i,k) + B(k,j)
        Ak += m;
        if (s > t) {
          t = s;
        }
      }
      *C++ = t;               // C(i,j), stored column-major
    }
  }
  return;
}
Timings, 100x100 input, 100 iterations:
% Matlab R2018b, i5 mobile:
Elapsed time is 0.365667 seconds. % 1st version from question
Elapsed time is 0.253912 seconds. % 2nd version from question
Elapsed time is 0.337319 seconds. % Bruno's version
Elapsed time is 0.077329 seconds. % C mex
Transposing A to get contiguous memory access gave no significant advantage.
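To test the indexing outside MATLAB, the mex loop can be reproduced as plain C. This is a sketch, not Jan's code: mp_prod_cm and mp_prod_cm_t are hypothetical names. Both assume MATLAB's column-major layout, where element (i,j) of an m-row matrix is stored at index i + j*m, so column j of the p-row matrix B starts at offset j*p (this coincides with j*n only for square inputs). The second function reads a pre-transposed A, so that row i of A is contiguous; this is the variant Jan reports gave no significant advantage.

```c
#include <math.h>
#include <stddef.h>

/* C (m x n) = A (m x p) (x) B (p x n), all column-major. */
void mp_prod_cm(const double *A, const double *B, double *C,
                size_t m, size_t n, size_t p)
{
    for (size_t j = 0; j < n; j++) {
        const double *Bj = B + j * p;              /* column j of B */
        for (size_t i = 0; i < m; i++) {
            double t = -INFINITY;
            for (size_t k = 0; k < p; k++) {
                double s = A[i + k * m] + Bj[k];   /* A(i,k) + B(k,j) */
                if (s > t)
                    t = s;
            }
            C[i + j * m] = t;
        }
    }
}

/* Same product, reading At = A.' (p x m, column-major), so that
   At[k + i*p] == A(i,k) and row i of A is contiguous. */
void mp_prod_cm_t(const double *At, const double *B, double *C,
                  size_t m, size_t n, size_t p)
{
    for (size_t j = 0; j < n; j++) {
        const double *Bj = B + j * p;
        for (size_t i = 0; i < m; i++) {
            const double *Ai = At + i * p;         /* row i of A */
            double t = -INFINITY;
            for (size_t k = 0; k < p; k++) {
                double s = Ai[k] + Bj[k];
                if (s > t)
                    t = s;
            }
            C[i + j * m] = t;
        }
    }
}
```

With A = [1 2 3; 4 5 6] and B = [0 1; -1 0; 1 -5], both functions return [4 2; 7 5]; the second column only comes out right because column j of B is taken at offset j*p.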
12 comments
Seth Younger
on 14 Sep 2021
@Jan I would love to connect with you and learn more about how this implementation works. Are you slicing the matrix up? I am confused by the (A,1) and (A,2).
More Answers (1)
Bruno Luong
on 13 Apr 2021
Edited: Bruno Luong
on 13 Apr 2021
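function C = mp_prod(A,B)
    m = size(A,1);
    n = size(B,2);
    AA = reshape(A,m,1,[]);     % m x 1 x p:  AA(i,1,k) = A(i,k)
    BB = reshape(B.',1,n,[]);   % 1 x n x p:  BB(1,j,k) = B(k,j)
    C = max(AA+BB,[],3);        % m x n x p array of sums, reduced over k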
tic/toc result:
>> A=rand(100);
>> B=rand(100);
>> tic; C = mp_prod(A,B); toc
Elapsed time is 0.005206 seconds.
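Bruno's version builds every sum A(i,k) + B(k,j) in an m×n×p array and then reduces over the third dimension; for 100×100 inputs that temporary has 10^6 elements (8 MB of doubles). The C sketch below performs the same reduction without the temporary by using the loop order j, k, i: each slice A(:,k) + B(k,j) is folded into C(:,j) with an elementwise max. The name mp_prod_jki is hypothetical and this ordering is not from the thread; the innermost loop walks A(:,k) and C(:,j) contiguously, which compilers can often vectorize, and the restrict qualifiers assume C does not overlap A or B.

```c
#include <math.h>
#include <stddef.h>

/* C (m x n) = A (m x p) (x) B (p x n), all column-major.
   Equivalent to max(AA + BB, [], 3) without the m x n x p array. */
void mp_prod_jki(const double *restrict A, const double *restrict B,
                 double *restrict C, size_t m, size_t n, size_t p)
{
    for (size_t j = 0; j < n; j++) {
        double *Cj = C + j * m;                 /* column j of C */
        for (size_t i = 0; i < m; i++)
            Cj[i] = -INFINITY;                  /* neutral element of max */
        for (size_t k = 0; k < p; k++) {
            const double *Ak = A + k * m;       /* column k of A */
            const double bkj = B[k + j * p];    /* scalar B(k,j) */
            for (size_t i = 0; i < m; i++) {
                double s = Ak[i] + bkj;         /* A(i,k) + B(k,j) */
                Cj[i] = s > Cj[i] ? s : Cj[i];
            }
        }
    }
}
```

No timings are claimed for this ordering; it only shows that the reduction over k can be done in place.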
8 comments
Jan
on 14 Apr 2021
Edited: Jan
on 14 Apr 2021
@Davide Zorzenon: Which machine and Matlab version are you using?
The different timings could mean that Davide's setup is more efficient for the loop, or that Bruno's setup is more efficient for the vectorized solution.