How is the number of parameters calculated when a multi-head self-attention layer is used in a CNN model?

Hana Ahmed on 28 August 2025 at 18:37
Commented: Umar on 30 August 2025 at 9:40
I have run the example in the following link in two cases:
Case 1: NumHeads = 4, NumKeyChannels = 784
Case 2: NumHeads = 8, NumKeyChannels = 392
Note that 4 x 784 = 8 x 392 = 3136, the size of the input feature vector to the attention layer. I calculated the number of model parameters in the two cases and got 9.8 M for the first case and 4.9 M for the second.
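A minimal way to check the attention layer's learnable count on its own is sketched below (it assumes R2023a or later, where selfAttentionLayer is available; the sequenceInputLayer is just a stand-in for the flattened CNN feature vector):

inputSize = 3136;
layers = [
    sequenceInputLayer(inputSize)
    selfAttentionLayer(4, 784)   % Case 1; use selfAttentionLayer(8, 392) for Case 2
    ];
net = dlnetwork(layers);
numLearnables = sum(cellfun(@numel, net.Learnables.Value))   % analyzeNetwork(net) reports the same total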
I expected the number of learnable parameters to be the same. However, MATLAB reports different parameter counts.
My understanding from research papers is that the total number of parameters should not depend on how the input is split across heads: as long as the input feature vector is the same and the product of the number of heads and the per-head size (number of channels) equals the input size, the parameter count should be identical.
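For comparison, in the standard formulation the four projection matrices alone contribute 4 * d_model * (numHeads * d_k) weights (ignoring biases), which is identical for the two configurations:

4 * 3136 * (4 * 784)   % = 39,337,984
4 * 3136 * (8 * 392)   % = 39,337,984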
Why does MATLAB’s selfAttentionLayer produce different parameter counts for these two configurations? Am I misinterpreting how the layer is implemented in this toolbox?
  3 comments
Hana Ahmed on 29 August 2025 at 3:44
I would be very grateful if you could suggest or provide a correct MATLAB implementation — ideally as a custom layer — that follows the standard multi-head attention equations.
Umar on 30 August 2025 at 9:40

Hi @Hana Ahmed,

Thanks for your follow-up! I think writing the multi-head attention mechanism from scratch would be a great way to get the transparency and control you're looking for. It will also help you understand the underlying principles better.

Here’s a quick skeleton in MATLAB to guide your implementation:

function Y = multiHeadAttention(X, numHeads, keyChannels)
    % X:           input matrix [batchSize, inputDim]
    %              (rows of X are treated as the tokens that attend to each other)
    % numHeads:    number of attention heads
    % keyChannels: dimensionality per head
    [batchSize, inputDim] = size(X);
    d_k = keyChannels; % dimension per head
    % Define weights for Q, K, V, and the output projection
    % (learnable in a real layer; randomly initialized here for illustration)
    W_Q = randn(inputDim, numHeads * d_k);
    W_K = randn(inputDim, numHeads * d_k);
    W_V = randn(inputDim, numHeads * d_k);
    W_O = randn(numHeads * d_k, inputDim);
    % Compute Q, K, V
    Q = X * W_Q;
    K = X * W_K;
    V = X * W_V;
    % Split into heads: [batchSize, d_k, numHeads], so head i is Q(:, :, i)
    Q = reshape(Q, batchSize, d_k, numHeads);
    K = reshape(K, batchSize, d_k, numHeads);
    V = reshape(V, batchSize, d_k, numHeads);
    % Scaled dot-product attention, computed head by head
    attentionOutput = zeros(batchSize, d_k, numHeads);
    for i = 1:numHeads
        Qi = Q(:, :, i);                       % [batchSize, d_k]
        Ki = K(:, :, i);
        Vi = V(:, :, i);
        scores = (Qi * Ki') / sqrt(d_k);       % [batchSize, batchSize]
        % Row-wise softmax (shift by the row max for numerical stability)
        weights = exp(scores - max(scores, [], 2));
        weights = weights ./ sum(weights, 2);
        attentionOutput(:, :, i) = weights * Vi;
    end
    % Concatenate heads and project to the output dimension
    attentionOutput = reshape(attentionOutput, batchSize, numHeads * d_k);
    Y = attentionOutput * W_O;
end
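
In this sketch the learnable weights are W_Q, W_K, W_V (each inputDim-by-numHeads*d_k) and W_O (numHeads*d_k-by-inputDim), i.e. 4 * inputDim * numHeads * d_k parameters in total, so the count depends only on the product numHeads * d_k, which matches the behaviour you expect. A quick toy call to sanity-check the shapes (the sizes are purely illustrative):

X = randn(10, 64);                  % 10 tokens with 64 features each
Y = multiHeadAttention(X, 4, 16);   % 4 heads x 16 channels per head = 64
size(Y)                             % returns [10 64]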

I suggest you try implementing this yourself in MATLAB, following the structure above. This will give you a hands-on understanding of how the attention mechanism works.

If you run into any issues or get stuck, feel free to reach out, and I’d be happy to help debug.

Good luck with the implementation!


Answers (0)
