Vous suivez désormais cette question
- Les mises à jour seront visibles dans votre flux de contenu suivi.
- Selon vos préférences en matière de communication il est possible que vous receviez des e-mails.
I'm using VIT transformer in my code. How to convert the output of 1D layer of VIT into 2D with format SSCB?
8 commentaires
Hi Abdulrahman,
I cannot execute the code because visionTransformer requires Computer Vision Toolbox. To illustrate resolving your error, I had to adapt your given code from mathworks for input dimensions of 24 x 24 x 768, by adjusting the reshaping and processing steps accordingly. Here is update the code step by step:
% Get Vision Transformer model
net = visionTransformer;
% Create dummy input
input = dlarray(rand(24,24,768),'SSCB');
% Obtain output embedding from the last LayerNormalizationLayer
out = forward(net, input, 'Outputs', 'encoder_norm');
% Reshape output patch embedding
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove output embedding corresponding to the class token from the input
out = in(2:end,:,:);
% Reshape the resulting embedding to the input format
WH = sqrt(size(out, 1));
C = size(out, 2);
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to formatted dlarray
out = dlarray(out, 'SSCB');
end
So, in my updated code snippet, I changed the dummy input dimensions to 24 x 24 x 768 to match the specified input size. The reshaping function reshapePatchEmbedding has been adjusted to handle the new dimensions correctly. Please let me know if this helps resolve your issue.
Réponses (2)
1 commentaire
Hi Abdulrahman,
I cannot execute the code because visionTransformer requires Computer Vision Toolbox. To illustrate resolving your error, I had to adapt your given code from mathworks for input dimensions of 24 x 24 x 768, by adjusting the reshaping and processing steps accordingly. Here is update the code step by step:
% Get Vision Transformer model
net = visionTransformer;
% Create dummy input
input = dlarray(rand(24,24,768),'SSCB');
% Obtain output embedding from the last LayerNormalizationLayer
out = forward(net, input, 'Outputs', 'encoder_norm');
% Reshape output patch embedding
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove output embedding corresponding to the class token from the input
out = in(2:end,:,:);
% Reshape the resulting embedding to the input format
WH = sqrt(size(out, 1));
C = size(out, 2);
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to formatted dlarray
out = dlarray(out, 'SSCB');
end
So, in my updated code snippet, I changed the dummy input dimensions to 24 x 24 x 768 to match the specified input size. The reshaping function reshapePatchEmbedding has been adjusted to handle the new dimensions correctly. Please let me know if this helps resolve your issue.
Voir également
Catégories
Tags
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!Une erreur s'est produite
Impossible de terminer l’action en raison de modifications de la page. Rechargez la page pour voir sa mise à jour.
Sélectionner un site web
Choisissez un site web pour accéder au contenu traduit dans votre langue (lorsqu'il est disponible) et voir les événements et les offres locales. D’après votre position, nous vous recommandons de sélectionner la région suivante : .
Vous pouvez également sélectionner un site web dans la liste suivante :
Comment optimiser les performances du site
Pour optimiser les performances du site, sélectionnez la région Chine (en chinois ou en anglais). Les sites de MathWorks pour les autres pays ne sont pas optimisés pour les visites provenant de votre région.
Amériques
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom(English)
Asie-Pacifique
- Australia (English)
- India (English)
- New Zealand (English)
- 中国
- 日本Japanese (日本語)
- 한국Korean (한국어)