Improve speed of linear interpolation in nested loops

Alessandro D on 24 Jul 2022
Commented: Alessandro D on 30 Aug 2022
I have to perform 1-dimensional linear interpolation many times inside 4 nested loops. My X-grid is sorted, so I can use interp1q, but the code is still too slow for my purposes. A simple vectorization that eliminates the innermost loop (leaving 3 loops instead of 4) makes it much faster, but unfortunately still not fast enough for my problem. Any suggestions on how to improve the speed? Thanks.
Below is a MWE (please bear in mind that in my real problem the grids are larger).
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c = 1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    %b_gridp(:,k_c) = linspace(b_min,b_max,nb)'; %EDITED HERE
    b_gridp(:,k_c) = linspace(b_min+rand,b_max-rand,nb)';
end
%% Slow, not vectorized code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 15.150729 seconds.
%% This is faster, but still not fast enough!
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            bnext = pol_debt(k_c,b_c,x_c);
            knext_ind = pol_kp_ind(k_c,b_c,x_c);
            pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
            dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % dim is (1,nx')
            dexit = min(max(dexit_inter,0),1); % dim is (1,nx')
            stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 0.717499 seconds.
err = max(abs(stay_arr-stay_arr2),[],'all')
err = 0
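The gain in the second version comes from handing the interpolation routine a whole matrix of values, so that a single call interpolates every column at the same query point. A minimal sketch of that behaviour on a toy grid, using interp1 rather than interp1q (the names xg, V, xq and the sizes are illustrative assumptions, not part of the original problem):

```matlab
% Toy illustration: one interp1 call with a matrix of values replaces a loop over columns.
xg = linspace(0, 1, 5)';      % sorted grid with 5 nodes (column vector)
V  = [xg, 2*xg, xg.^2];       % 5-by-3: each column is a separate set of values on xg
xq = 0.3;                     % single query point

% Loop version: interpolate one column at a time
out_loop = zeros(1, size(V,2));
for j = 1:size(V,2)
    out_loop(j) = interp1(xg, V(:,j), xq);
end

% Vectorized version: all columns in a single call, result is 1-by-3
out_vec = interp1(xg, V, xq);

disp(max(abs(out_loop - out_vec)))
```

The same idea is what removes the loop over xp_c above: pol_exitp(:,:,knext_ind) is an nb-by-nx matrix, so one call returns the whole (1,nx) row.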
5 comments
Alessandro D on 25 Jul 2022
Edited: Alessandro D on 25 Jul 2022
Thanks for your suggestion! I think there are two different issues here: (1) making the interpolation faster by exploiting the fact that the grid is equally spaced, and (2) vectorizing the code. I wrote a small function, find_loc_equi, that returns the bins and weights for interpolation on an equally spaced grid. The new code is below. The speed improvement is minimal, however. What I don't know is how to vectorize the code further: e.g., is it possible to remove the inner loop over k_c?
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c = 1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
start code
%% Non-vectorized code with interp1q
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 14.475729 seconds.
%% Vectorized code with faster bin location
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital CAN I ELIMINATE THIS LOOP?
            bnext = pol_debt(k_c,b_c,x_c);
            knext_ind = pol_kp_ind(k_c,b_c,x_c);
            pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
            [left_loc,weights] = find_loc_equi(b_gridp(:,knext_ind),bnext); % scalars
            dexit_inter = weights*pol_exit_bx(left_loc,:)+(1-weights)*pol_exit_bx(left_loc+1,:); % dim is (1,nx')
            dexit = min(max(dexit_inter,0),1); % dim is (1,nx')
            stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 1.136634 seconds.
err2 = max(abs(stay_arr-stay_arr2),[],'all')
err2 = 1.5765e-14
function [bins,weights] = find_loc_equi(b_grid,bnext)
    % Find the left interpolating node and the corresponding weight
    % Assumptions: b_grid is equally spaced
    % Thanks to Walter Roberson
    n = size(b_grid,1);
    deltas = b_grid(2,:)-b_grid(1,:);
    initial_values = b_grid(1,:);
    frac = (bnext-initial_values)./deltas;
    bins = floor(frac)+1;
    bins = min(n-1,max(1,bins));
    weight_right = mod(frac,1);
    weights = 1-weight_right;
    weights(bnext>=b_grid(n)) = 0 ;
    weights(bnext<=b_grid(1)) = 1 ;
end
Alessandro D on 30 Aug 2022
The last two lines of the function find_loc_equi should read:
weights(bnext>=b_grid(n,:)) = 0 ;
weights(bnext<=b_grid(1,:)) = 1 ;
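The correction matters once b_grid has more than one column: b_grid(n) is a linear index and refers only to the last node of the first column, whereas b_grid(n,:) picks the last node of every column. In the loop above the function is called with a single column, where the two coincide. A small sketch of the difference, with the body of find_loc_equi written inline and one query point per column (the toy grids, y and bnext are illustrative assumptions):

```matlab
% Sketch: why the end-point tests must index every column of the grid.
b_grid = [linspace(0,10,6)', linspace(0,5,6)'];   % two equally spaced grids, 6 nodes each
y      = [(1:6)', 10*(1:6)'];                     % toy values at the nodes
bnext  = [4, 5];                                  % one query per column; 5 is the last node of column 2

n      = size(b_grid,1);
deltas = b_grid(2,:) - b_grid(1,:);
frac   = (bnext - b_grid(1,:)) ./ deltas;
bins   = min(n-1, max(1, floor(frac)+1));         % left node, clamped to 1..n-1
w      = 1 - mod(frac,1);                         % weight on the left node

% Original end-point tests: compare against column 1 only
w_old = w;
w_old(bnext >= b_grid(n)) = 0;
w_old(bnext <= b_grid(1)) = 1;

% Corrected tests: compare each query with the end points of its own column
w_new = w;
w_new(bnext >= b_grid(n,:)) = 0;
w_new(bnext <= b_grid(1,:)) = 1;

cols    = 1:size(b_grid,2);
left    = y(sub2ind(size(y), bins,   cols));
right   = y(sub2ind(size(y), bins+1, cols));
res_old = w_old.*left + (1-w_old).*right;
res_new = w_new.*left + (1-w_new).*right;

% Reference: interp1, column by column
ref = [interp1(b_grid(:,1), y(:,1), bnext(1)), interp1(b_grid(:,2), y(:,2), bnext(2))];
disp([res_old; res_new; ref])
```

In this toy case the original test leaves the full weight on the left node for the second column, so it returns the value at node 5 rather than node 6; the corrected test agrees with interp1.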

Accepted Answer

Bruno Luong on 25 Jul 2022
Edited: Bruno Luong on 25 Jul 2022
This seems to work
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c = 1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
start code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 15.097401 seconds.
%% Full vectorized code
tic
bgridcommon = b_gridp(:,1);
Y = interp1(bgridcommon,(1:nb)',pol_debt); % nk x nb x nx
Yt = max(min(Y,nb-1),1); % clamp to [1,nb-1]; not needed if the data never fall outside the grid
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = repmat(pol_kp_ind,[1 1 1 nx]);
K = reshape(K,size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
Elapsed time is 0.191299 seconds.
err = norm(stay_arr2(:)-stay_arr(:),Inf)
err = 8.6597e-15
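The vectorized code builds linear indices into pol_exitp with sub2ind and then reads the right-hand neighbour as pol_exitp(rhsilin+1). This relies on MATLAB storing arrays in column-major order, so that adding 1 to a linear index moves one step along the first dimension (here, the b-grid), provided the row subscript is at most nb-1. A small sketch on a toy array (A, I, J, K are illustrative names, not taken from the code above):

```matlab
% Sketch: lin+1 addresses the next element along the first dimension.
A = reshape(1:24, [4 3 2]);          % toy 3-D array; dimension 1 plays the role of the b-grid
I = [1 2 3];                         % row subscripts (left nodes, at most size(A,1)-1)
J = [1 3 2];                         % column subscripts
K = [2 1 2];                         % page subscripts

lin   = sub2ind(size(A), I, J, K);   % linear indices of the left nodes
left  = A(lin);
right = A(lin + 1);                  % neighbours one row further down

% Same neighbours read with explicit subscripts
right_check = zeros(size(I));
for t = 1:numel(I)
    right_check(t) = A(I(t)+1, J(t), K(t));
end
disp(isequal(right, right_check))
```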
4 comments
Bruno Luong on 26 Jul 2022
Edited: Bruno Luong on 26 Jul 2022
Sorry, forget my comment above about the loop. The bin interval is not the first index. Here is the corrected code, which works for variable bin vectors.
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c = 1:nk
    % in general, the columns of b_gridp are *not* equal to each other
    b_gridp(:,k_c) = linspace(b_min-rand(),b_max+rand(),nb)';
end
disp('start code')
start code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
    for b_c = 1:nb % current debt
        for k_c = 1:nk % current capital
            for xp_c = 1:nx
                bnext = pol_debt(k_c,b_c,x_c);
                knext_ind = pol_kp_ind(k_c,b_c,x_c);
                pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
                dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
                dexit = min(max(dexit_inter,0),1); % scalar
                stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
            end
        end %k_c
    end %b_c
end %x_c
toc
Elapsed time is 15.051721 seconds.
%% Full vectorized code
tic
K = pol_kp_ind;
bminK = reshape(b_gridp(1,K),size(K));
bmaxK = reshape(b_gridp(nb,K),size(K));
Y = 1 + (nb-1) * (pol_debt - bminK) ./ (bmaxK-bminK);
Yt = max(min(Y,nb-1),1); % clamp to [1,nb-1]; not needed if the data never fall outside the grid
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = reshape(repmat(K,[1 1 1 nx]),size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
Elapsed time is 0.166581 seconds.
err = norm(stay_arr2(:)-stay_arr(:),Inf)
err = 1.7542e-14
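The formula for Y above takes the place of the interp1 call in the first version: on a grid built with linspace, the fractional position of a query point follows directly from the two end points of that grid. A minimal check on a toy grid (g, xq, vals and the sizes are illustrative assumptions; the sketch relies on the grid being equally spaced, as in the MWE):

```matlab
% Sketch: fractional index on an equally spaced grid, with and without interp1.
nb = 9;
g  = linspace(-3, 5, nb)';                   % equally spaced grid (column vector)
xq = [-3, -1.7, 0.2, 4.99, 5];               % queries inside the grid, end points included

Y_formula = 1 + (nb-1) * (xq - g(1)) ./ (g(nb) - g(1));
Y_interp  = interp1(g, (1:nb)', xq);         % same quantity obtained by interpolation

I = floor(max(min(Y_formula, nb-1), 1));     % left node, clamped to 1..nb-1
W = Y_formula - I;                           % weight on the right node

vals   = cos(g);                             % arbitrary values at the nodes
direct = (1-W).*reshape(vals(I), size(xq)) + W.*reshape(vals(I+1), size(xq));
ref    = interp1(g, vals, xq);

disp(max(abs(Y_formula - Y_interp)))
disp(max(abs(direct - ref)))
```

If the columns of b_gridp were not equally spaced, this shortcut would no longer apply and the left node would have to be located by a search instead.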
Alessandro D on 27 Jul 2022
Thanks, this works and is significantly faster!

Release: R2020b
