- Remove all layers after the last layerNormalizationLayer in the Vision Transformer backbone.
- Create a (formattable) custom layer and use the code from the reshapePatchEmbedding function in the example below as its predict method. Alternatively, you can use a (formattable) functionLayer with a function handle to reshapePatchEmbedding.
- Connect the last layerNormalizationLayer in the Vision Transformer backbone to the custom/function layer, and connect the custom/function layer to the input of the decoder network.
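The functionLayer option in the steps above can be sketched roughly as follows. This is a minimal sketch, not a tested integration with the pretrained network: the dummy `tokens` array stands in for the output of the last layerNormalizationLayer (assumed here to be 577 tokens by 768 channels in 'SCB' format), and the layer name "reshape_patches" is a hypothetical choice. The function body is the same reshapePatchEmbedding used in the example in this answer, repeated so the sketch is self-contained.

```matlab
% Dummy stand-in for the encoder output (assumption): 577 tokens
% (1 class token + 24*24 patch tokens), 768 channels, batch of 2.
tokens = dlarray(rand(577, 768, 2, 'single'), 'SCB');

% Formattable function layer wrapping the reshape function.
% The name "reshape_patches" is hypothetical.
reshapeLayer = functionLayer(@reshapePatchEmbedding, ...
    Formattable=true, Name="reshape_patches");

% Call the function directly to check what the layer would produce.
patches = reshapePatchEmbedding(tokens);
disp(size(patches))
disp(dims(patches))

function out = reshapePatchEmbedding(in)
    % Drop the embedding that corresponds to the class token
    out = in(2:end,:,:);
    % Split the token dimension into a square patch grid
    WH = sqrt(size(out, 1));
    C = size(out, 2);
    out = reshape(out, WH, WH, C, []);  % W x H x C x N
    out = permute(out, [2, 1, 3, 4]);   % H x W x C x N
    % Relabel as a formatted dlarray
    out = dlarray(out, 'SSCB');
end
```

Once the trailing layers are removed, a layer like `reshapeLayer` would be attached after the last layerNormalizationLayer (for example with addLayers and connectLayers) and its output fed to the decoder input.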
How can I reassemble 'patch embedded' data back into original data structure in Vision Transformer on DeepNetworkDesigner?
Brian Park
on 24 Jun 2023
Answered: Samuel Somuyiwa
on 24 Jul 2023
I'm trying to fine-tune the Vision Transformer model from 'Computer Vision Toolbox Model for Vision Transformer Network' for instance segmentation.
So I need to reassemble the SCB data back into SSCB data using the position information in the positional vector, but I cannot find a layer that does this.
How can I separate the positional vector from the 577 (S) tokens and convert the remaining 576 (S) tokens into SS format?
Accepted Answer
Samuel Somuyiwa
on 24 Jul 2023
Assuming you are using the Vision Transformer model as a backbone/encoder, you can obtain the output embedding from the last block (the last layerNormalizationLayer, before indexing1dLayer) and write a function to remove the output embedding corresponding to the class token, and reshape the resulting embedding to SSCB format. See example in the code below.
The example assumes that your decoder is a model function. If the decoder is a dlnetwork or layer graph, follow the three steps listed at the top of this page instead.
% Get Vision Transformer model
net = visionTransformer;
% Create dummy input
input = dlarray(rand(384,384,3),'SSCB');
% Obtain output embedding from last layerNormalizationLayer
out = forward(net, input, Outputs='encoder_norm');
% Reshape output patch embedding
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
    % Remove output embedding corresponding to class token from input
    out = in(2:end,:,:);
    % Reshape resulting embedding to input format
    WH = sqrt(size(out, 1));
    C = size(out, 2);
    out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
    out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
    % Convert to formatted dlarray
    out = dlarray(out, 'SSCB');
end
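To see why the function reshapes to W x H first and then permutes, here is a small worked example on a plain numeric array (no toolbox needed). It assumes, as the comments in reshapePatchEmbedding imply, that the patch tokens follow a row-major scan of the patch grid; the 3-by-3 grid and the variable names are illustrative only.

```matlab
% Toy token sequence: 1 class token followed by 9 patch tokens, 1 channel.
% Patch k is assumed to come from a row-major scan of a 3-by-3 patch grid.
tokens = [0; (1:9)'];              % 10-by-1 (S-by-C), class token marked as 0

patchTokens = tokens(2:end, :);    % drop the class token -> 9-by-1
gridSize = sqrt(size(patchTokens, 1));
numChannels = size(patchTokens, 2);

% reshape fills the first dimension fastest, so consecutive tokens
% (neighbours along a row) land along dimension 1, i.e. W comes first.
patchGrid = reshape(patchTokens, gridSize, gridSize, numChannels, []); % W x H x C x N
patchGrid = permute(patchGrid, [2 1 3 4]);                             % H x W x C x N

disp(patchGrid)
% Displays:
%      1     2     3
%      4     5     6
%      7     8     9
```

Without the permute, the grid would come out transposed, so the reassembled feature map would be mirrored across its diagonal relative to the input image.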