Update BatchNorm Layer State in Siamese Network with Custom Loop for Triplet and Contrastive Loss

Hi everyone, I'm trying to implement a Siamese network for face verification. As the subnetwork I'm using a ResNet-18 pretrained on my dataset, and I'm trying to implement the triplet loss and the contrastive loss. The main problem is the batch normalization layers in my subnetwork, whose state needs to be updated during the training phase.
However, the MathWorks tutorials I found only show this update with a cross-entropy loss, where a single dlarray is passed to the forward function, which returns the state:
function [loss,gradients,state] = modelLoss(net,X,T)
[Y,state] = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);
At the moment, this is my training loop for the contrastive loss; there is a similar one for the triplet loss that takes three images at a time:
for iteration = 1:numIterations
    [X1,X2,pairLabels] = GetSiameseBatch(IMGS, miniBatchSize);
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    % clear X1 X2
    % Load the pairs into GPU memory
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end
    % Evaluate the loss, the gradients and the network state using
    % dlfeval and the modelLoss function
    [loss,gradientsSubnet,state] = dlfeval(@modelLoss,dlnet,dlX1,dlX2,pairLabels);
    dlnet.State = state;
    % Update the Siamese subnetwork parameters. Scope: train the last fc
    % layer that outputs the 128-dim feature vector.
    % (learningRate, gradDecay and gradDecaySq are set before the loop)
    [dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
        adamupdate(dlnet,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    lossValue = double(gather(extractdata(loss)));
    % lossValue = double(loss);
    title("Elapsed: " + string(D))
end
And the model loss function is:
function [loss,gradientsSubnet,state] = modelLoss(net,X1,X2,pairLabels)
    % Pass the image pair through the network.
    [F1,F2,state] = ForwardSiamese(net,X1,X2);
    % Calculate the contrastive loss.
    margin = 1;
    loss = ContrastiveLoss(F1,F2,pairLabels, margin);
    % Calculate gradients of the loss with respect to the network learnable
    % parameters.
    gradientsSubnet = dlgradient(loss,net.Learnables);
end
In the ForwardSiamese function I run the forward pass on the two dlarrays X1 and X2, which hold the batch of image pairs (e.g. X1 contains 32 images and so does X2; the first image in X1 is paired with the first image in X2, and so on), and then compute the loss. But where should the state used to update the batch norm layers come from?
function [Y1,Y2,state] = ForwardSiamese(dlnet,dlX1,dlX2)
    % Pass the first image through the subnetwork and keep its state
    [Y1,state] = forward(dlnet,dlX1);
    Y1 = sigmoid(Y1);
    % Pass the second image through the twin subnetwork
    Y2 = forward(dlnet,dlX2);
    Y2 = sigmoid(Y2);
end
If I also compute [Y2,state], I get two states. Which one should be used to update the batch norm TrainedMean and TrainedVariance?

Accepted Answer

Joss Knight on 6 Nov 2022
Edited: Joss Knight on 6 Nov 2022
Interesting question! The purpose of batch norm state is to collect statistics about typical inputs. In a normal Siamese workflow, both X1 and X2 are valid inputs, so you ought to be able to update the state with either result.
You could aggregate the state from both, or even do an additional pass with both inputs to compute the aggregated state, although this comes with an extra performance cost. For example:
[~,dlnet.State] = forward(dlnet, cat(4,X1,X2));
You can do this after the call to dlfeval.
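A minimal sketch of that extra pass wrapped as a helper, assuming dlnet is a dlnetwork and dlX1/dlX2 are the 'SSCB' dlarrays from the training loop in the question. The name updateBatchNormState is hypothetical. Because it runs outside dlfeval, no gradients are traced, so the added cost is one forward pass over twice the mini-batch size.

```matlab
function dlnet = updateBatchNormState(dlnet, dlX1, dlX2)
% Hypothetical helper: recompute the batch norm running statistics
% (TrainedMean/TrainedVariance) from both halves of the pair batch.
% dlX1 and dlX2 are 'SSCB' dlarrays with the same spatial/channel size.
X = cat(4, dlX1, dlX2);          % concatenate along the batch dimension
[~, state] = forward(dlnet, X);  % training-mode pass returns the updated state
dlnet.State = state;             % store it so predict uses the new statistics
end
```

In the loop you would call dlnet = updateBatchNormState(dlnet,dlX1,dlX2); right after dlfeval, in place of dlnet.State = state;.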
4 Comments
Joss Knight on 14 Nov 2022
If you are only fine-tuning part of the network then you only need to update the state for the part you are modifying. It looks like in your case that part doesn't contain any batch normalization layers. In which case, don't update the State at all!
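One way to check this in code is to look at which layers appear in the network's State table (for ResNet-18 these are the batch normalization layers) and compare them with the layers you actually update. A sketch, assuming you keep the names of the trained layers in a string array; the function name needsStateUpdate and the layer name fc128 used below are hypothetical.

```matlab
function tf = needsStateUpdate(net, trainedLayerNames)
% Hypothetical helper: true if any layer being trained has state
% (e.g. batch norm TrainedMean/TrainedVariance), i.e. if dlnet.State
% should be refreshed during training.
% net: a dlnetwork, or any struct/object whose State table has a
%      Layer column; trainedLayerNames: names of the trained layers.
stateLayers = unique(string(net.State.Layer));
tf = any(ismember(string(trainedLayerNames), stateLayers));
end
```

For example, if needsStateUpdate(dlnet,"fc128") is false, the dlnet.State = state; line can simply be skipped.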
Filippo Vascellari on 14 Nov 2022
Great news about the bug; maybe I can speed up the process.
As for the state, I do need the update, because when I use the classification loss I have to train the whole network, not only the layers after the pooling layer of the ResNet-18 backbone.
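If the extra forward pass is too expensive, the other option from the accepted answer is to aggregate the two states that ForwardSiamese already produces. A sketch, assuming ForwardSiamese is changed to also return the second state ([Y2,state2] = forward(dlnet,dlX2);) and both states start from the same dlnet.State; averageState is a hypothetical helper. With the moving-average update batch norm uses for TrainedMean, averaging the two updated means matches an update with the concatenated batch when X1 and X2 have the same size. For TrainedVariance it is only an approximation, because the spread between the two half-batch means is ignored.

```matlab
function state = averageState(state1, state2)
% Hypothetical helper: element-wise average of two state tables
% (Layer, Parameter, Value) returned by forward for the same network.
% Assumes both tables list the same parameters in the same row order.
assert(isequal(state1.Layer, state2.Layer) && isequal(state1.Parameter, state2.Parameter), ...
    'State tables must describe the same layers and parameters.');
state = state1;
for i = 1:height(state)
    % Value is a cell column; average the running statistics entry by entry
    state.Value{i} = (state1.Value{i} + state2.Value{i}) / 2;
end
end
```

ForwardSiamese could then return state = averageState(state1,state2); and the training loop stays unchanged.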
