Vectorizing bsxfun
Hi,
I have two matrices (which are really lists of vectors) and would like a matrix of the pair-wise squared distances between all of them. The following code does what I want, but I'm curious if there's any way to vectorize this.
Thank you
rX=rand(nTrain,numDim);
rXClass=rand(nClass,numDim);
dists=zeros(nTrain,nClass);
for ii=1:nTrain
    thisX=rX(ii,:);
    dists(ii,:)=sum(bsxfun(@minus,thisX,rXClass).^2,2)/D;
end
Answers (3)
the cyclist
on 16 Aug 2011
rX2 = permute(rX,[1 3 2]);
rXClass2 = permute(rXClass,[3 1 2]);
dists = sum(bsxfun(@minus,rX2,rXClass2).^2,3)/D;
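For readers following the shapes: the two permute calls move the coordinate index to the third dimension, so that bsxfun can expand the remaining singleton dimensions into all row pairs. A minimal sketch with small made-up sizes (the sizes, D = 1, and the names diffs and direct are assumptions for illustration, not from the post):

```matlab
% Small illustrative sizes (assumed; the post above does not give any).
nTrain = 4; nClass = 3; numDim = 2; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

rX2      = permute(rX,      [1 3 2]);   % nTrain x 1      x numDim
rXClass2 = permute(rXClass, [3 1 2]);   % 1      x nClass x numDim

% bsxfun expands the singleton dimensions: nTrain x nClass x numDim.
diffs = bsxfun(@minus, rX2, rXClass2);
dists = sum(diffs.^2, 3) / D;           % nTrain x nClass

% Spot-check one entry against the direct definition.
ii = 2; jj = 3;
direct = sum((rX(ii,:) - rXClass(jj,:)).^2) / D;
disp(abs(dists(ii,jj) - direct) < 1e-12)
```

The intermediate array has nTrain*nClass*numDim elements, so memory use grows quickly with the three sizes.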
Sean de Wolski
on 16 Aug 2011
dists2 = squeeze(sum(bsxfun(@minus,rX,reshape(rXClass',[1 numDim nClass])).^2,2))/D;
With all three sizes equal to 150, your plain for-loop runs the fastest for me:
nTrain = 150;
numDim = 150;
nClass = 150;
D = 1;
rX=rand(nTrain,numDim);
rXClass=rand(nClass,numDim);
t1 = 0;
t2 = 0;
t3 = 0;
for jj = 1:50
    % Method 1: original for-loop
    tic
    dists=zeros(nTrain,nClass);
    for ii=1:nTrain
        thisX=rX(ii,:);
        dists(ii,:)=sum(bsxfun(@minus,thisX,rXClass).^2,2)/D;
    end
    t1 = t1+toc;
    % Method 2: reshape + squeeze
    tic
    dists2 = squeeze(sum(bsxfun(@minus,rX,reshape(rXClass',[1 numDim nClass])).^2,2))/D;
    t2 = t2+toc;
    % Method 3: permute
    tic
    rX2 = permute(rX,[1 3 2]);
    rXClass2 = permute(rXClass,[3 1 2]);
    dists3 = sum(bsxfun(@minus,rX2,rXClass2).^2,3)/D;
    t3 = t3+toc;
end
isequal(dists,dists2,dists3)
[t1 t2 t3]
ans =
1
ans =
3.1505 4.0336 4.0368
4 comments
Sean de Wolski
on 16 Aug 2011
The biggest time sink is actually the .^2. Removing that doubles the speed :(
Brendan
on 16 Aug 2011
Sean de Wolski
on 16 Aug 2011
Nope, probably not. The only hope I can think of for that one is maybe with James' mtimesx on File Exchange.
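An alternative formulation, not proposed by anyone in this thread, avoids both the 3-D intermediate array and the element-wise .^2 on it by expanding the square: |x - y|^2 = |x|^2 + |y|^2 - 2*x*y'. The pairwise dot products then come from a single ordinary matrix product. The sketch below is only an illustration; the names sqX, sqC, crossTerm and dists4 are made up here, and it has not been timed against the versions above, so whether it is faster will depend on the sizes and the MATLAB release.

```matlab
% Sizes taken from the timing test in this thread; D = 1 as there.
nTrain = 150; nClass = 150; numDim = 150; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

sqX       = sum(rX.^2, 2);        % nTrain x 1, squared norm of each row of rX
sqC       = sum(rXClass.^2, 2)';  % 1 x nClass, squared norm of each row of rXClass
crossTerm = rX * rXClass';        % nTrain x nClass, all pairwise dot products

dists4 = bsxfun(@plus, sqX, sqC) - 2*crossTerm;
dists4 = max(dists4, 0) / D;      % clamp tiny negatives caused by rounding

% Compare against the loop from the question.
dists = zeros(nTrain, nClass);
for ii = 1:nTrain
    dists(ii,:) = sum(bsxfun(@minus, rX(ii,:), rXClass).^2, 2) / D;
end
maxErr = max(abs(dists4(:) - dists(:)));
fprintf('max abs difference: %g\n', maxErr);
```

Because the arithmetic is ordered differently, the result agrees with the loop only up to floating-point rounding, so isequal would generally return 0 here; compare with a tolerance instead.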
the cyclist
on 17 Aug 2011
If you prefer the vectorized versions, both can be sped up a fair amount by splitting the one-liner that computes "dists" into separate lines for the bsxfun call, the squaring, and the sum.
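A sketch of what that split might look like for the permute version. The intermediate names diffs and sq are made up for illustration, and the speed-up is the commenter's observation, not something measured here:

```matlab
% Sizes taken from the timing test in this thread; D = 1 as there.
nTrain = 150; nClass = 150; numDim = 150; D = 1;
rX      = rand(nTrain, numDim);
rXClass = rand(nClass, numDim);

rX2      = permute(rX,      [1 3 2]);
rXClass2 = permute(rXClass, [3 1 2]);

% One step per line instead of a single nested expression.
diffs  = bsxfun(@minus, rX2, rXClass2);  % pairwise differences, nTrain x nClass x numDim
sq     = diffs.^2;                       % element-wise squares
dists3 = sum(sq, 3) / D;                 % sum over coordinates, nTrain x nClass

% The original one-liner, for comparison.
oneLiner = sum(bsxfun(@minus, rX2, rXClass2).^2, 3) / D;
fprintf('max abs difference: %g\n', max(abs(dists3(:) - oneLiner(:))));
```

To check the timing claim on your own machine, wrap each variant in tic/toc as in the test above.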
Andrei Bobrov
on 16 Aug 2011
my small contribution:
[a, b] = meshgrid(1:nTrain,1:nClass);
dists = reshape(sum((rX(a(:),:) - rXClass(b(:),:)).^2,2),[],nTrain)'/D;