How does selfAttentionLayer work, and how can it be validated with brief code?

How does selfAttentionLayer work in detail, step by step? Can its working process be reproduced simply from the formulas in the paper, so that the correctness and consistency of selfAttentionLayer can be verified?
Official description:
A self-attention layer computes single-head or multihead self-attention of its input.
The layer:
  1. Computes the queries, keys, and values from the input
  2. Computes the scaled dot-product attention across heads using the queries, keys, and values
  3. Merges the results from the heads
  4. Performs a linear transformation on the merged result

Accepted Answer

cui,xingxing on 11 Jan 2024
Edited: cui,xingxing on 27 Apr 2024
Here is a simple code workflow that I wrote, using an input with only two dimensions, "CT" (channel and time), to illustrate how each step works.
The comment after each variable gives its dimensions.
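The code that follows can be summarized by the equations below. This is a sketch transcribed from the code itself, not from MathWorks documentation. It assumes channels run along rows and time steps along columns, writes $H$ for `numHeads`, $d_k = N_1/H$ and $d_v = N_2/H$ for the per-head channel counts, and uses $\operatorname{softmax}_{\text{col}}$ for a softmax that normalizes each column.

```latex
\[
\begin{aligned}
Q &= W_Q X + b_Q \in \mathbb{R}^{N_1 \times T}, \quad
K = W_K X + b_K \in \mathbb{R}^{N_1 \times T}, \quad
V = W_V X + b_V \in \mathbb{R}^{N_2 \times T} \\
A_i &= \operatorname{softmax}_{\text{col}}\!\left(\frac{K_i^{\top} Q_i}{\sqrt{d_k}}\right) \in \mathbb{R}^{T \times T},
\qquad O_i = V_i A_i \in \mathbb{R}^{d_v \times T}, \qquad i = 1,\dots,H \\
O &= \begin{bmatrix} O_1 \\ \vdots \\ O_H \end{bmatrix} \in \mathbb{R}^{N_2 \times T} \\
Y &= W_O\, O + b_O \in \mathbb{R}^{N_3 \times T}
\end{aligned}
\]
```

With tokens stored as rows instead of columns, the second line is the transpose of the familiar form $\operatorname{softmax}(QK^{\top}/\sqrt{d_k})\,V$, which is why $K_i^{\top} Q_i$ appears here rather than $Q_i^{\top} K_i$.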
%% Verify that the selfAttentionLayer computation matches a manual calculation
XTrain = dlarray(rand(10,20));% CT
numClasses = 4;
numHeads = 6;
queryDims = 48; % N1=48
layers = [inputLayer(size(XTrain),"CT");
selfAttentionLayer(numHeads,queryDims,NumValueChannels=12,OutputSize=15,Name="sa");
layerNormalizationLayer;
fullyConnectedLayer(numClasses);
softmaxLayer];
net = dlnetwork(layers);
% analyzeNetwork(net)
XTrain = dlarray(XTrain,"CT");
[act1,act2] = predict(net,XTrain,Outputs=["input","sa"]);
act1 = extractdata(act1);% CT
act2 = extractdata(act2);% CT
% layer params
layerSA = net.Layers(2);
QWeights = layerSA.QueryWeights; % N1*C
KWeights = layerSA.KeyWeights;% N1*C
VWeights = layerSA.ValueWeights;% N2*C
outputW = layerSA.OutputWeights;% N3*N2
Qbias = layerSA.QueryBias; % N1*1
Kbias = layerSA.KeyBias;% N1*1
Vbias = layerSA.ValueBias; % N2*1
outputB = layerSA.OutputBias;% N3*1
% step1
q = QWeights*act1+Qbias; % N1*T
k = KWeights*act1+Kbias;% N1*T
v = VWeights*act1+Vbias;% N2*T
% step2,multiple heads
numChannelsQPerHeads = size(q,1)/numHeads;% 1*1
numChannelsVPerHeads = size(v,1)/numHeads;% 1*1
attentionM = cell(1,numHeads);
for i = 1:numHeads
    idxQRange = numChannelsQPerHeads*(i-1)+1:numChannelsQPerHeads*i;
    idxVRange = numChannelsVPerHeads*(i-1)+1:numChannelsVPerHeads*i;
    qi = q(idxQRange,:); % diQ*T
    ki = k(idxQRange,:); % diQ*T
    vi = v(idxVRange,:); % diV*T
    % attention
    dk = size(qi,1); % 1*1
    attentionScores = mysoftmax(ki'*qi./sqrt(dk)); % T*T, note: MATLAB's internal code uses k'*q, not q'*k
    attentionM{i} = vi*attentionScores; % diV*T
end
%step3,merge attentionM
attention = cat(1,attentionM{:}); % N2*T,N2 = diV*numHeads
%step4,output linear projection
act_ = outputW*attention+outputB;% N3*T
act2(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
act_(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
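Comparing one printed row by eye only checks a small part of the output. A minimal sketch of a whole-array check follows. The helper `isClose` and the tolerance `1e-5` are assumptions introduced here, not part of the original script. The tolerance is a guess suited to single-precision weights.

```matlab
% Hypothetical helper (not in the original answer): true when the largest
% element-wise difference between two equally sized arrays is within tol.
isClose = @(a, b, tol) max(abs(a - b), [], "all") <= tol;

% Self-contained illustration on small arrays.
A = [1 2; 3 4];
okSmall = isClose(A, A + 1e-7, 1e-5);   % true: the difference is about 1e-7
okLarge = isClose(A, A + 1, 1e-5);      % false: the difference is 1

% In the workspace of the script above, the same check would read:
% assert(isClose(act2, act_, 1e-5), "Manual result differs from the layer output.")
```

If the assertion passes, every element of `act_` agrees with `act2` to within the chosen tolerance, not only the first row.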
I have reproduced its working process in the simplest possible way; I hope it helps others.
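function out = mysoftmax(X,dim)
% Softmax of X along dimension dim (default: 1, i.e. each column sums to 1).
arguments
    X
    dim = 1;
end
% X = X-max(X,[],dim); % prevents exp from overflowing to Inf when X is large
X = exp(X);
out = X./sum(X,dim);
end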
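The commented-out line in `mysoftmax` subtracts the maximum along the softmax dimension before exponentiating. A small sketch of why that matters is below. The matrix `X` is made up for illustration and uses only base MATLAB. Subtracting the maximum does not change the softmax mathematically, but it keeps `exp` from overflowing.

```matlab
% Made-up input: the entry 1000 is large enough that exp(1000) is Inf in double.
X = [1 2; 3 4; 1000 5];

% Naive softmax over dimension 1: column 1 produces Inf/Inf = NaN.
naive = exp(X) ./ sum(exp(X), 1);

% Shifted softmax: subtract each column's maximum first, so the largest
% exponent is 0 and exp never overflows.
Xs = X - max(X, [], 1);
stable = exp(Xs) ./ sum(exp(Xs), 1);

hasNaN   = any(isnan(naive), "all");     % true for this input
isFinite = all(isfinite(stable), "all"); % true: every entry is finite
colSums  = sum(stable, 1);               % each column sums to 1 up to rounding
```

For the small random inputs used in the script above, the two versions agree, so the unshifted form does not affect the comparison there.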

Release: R2023b
