How to implement a Siamese network whose two subnetworks do not share weights

I was implementing a Siamese network using the MATLAB Deep Learning Toolbox. It is easy to implement such a network when the two subnetworks share weights, following this official demo. Now I want to implement a Siamese network whose two subnetworks do not share weights. Are there any easy solutions? I know we could set up two "dlnetwork" objects, one for input image A and the other for input image B. But the problem is that you then need to load both subnetworks into GPU memory, which is not feasible when there is not enough memory.
Any good solutions are welcome, thank you!

Answers (1)

Joss Knight on 1 Sep 2022
You can try gathering the weights back from each network after you've used it, as in net = dlupdate(@gather,net). This should save some memory.
Joss Knight on 10 Sep 2022
I'm imagining that you would do something like this, in your forwardSiamese function:
% Move the first subnetwork's weights onto the GPU and pass the first image through it
dlnet1 = dlupdate(@gpuArray,dlnet1);
F1 = forward(dlnet1,dlX1);
F1 = sigmoid(F1);
% Gather the weights back to host memory before loading the second subnetwork
dlnet1 = dlupdate(@gather,dlnet1);
dlnet2 = dlupdate(@gpuArray,dlnet2);
% Pass the second image through the twin subnetwork
F2 = forward(dlnet2,dlX2);
F2 = sigmoid(F2);
dlnet2 = dlupdate(@gather,dlnet2);
For this to work you will need to ensure that you always pass your two networks to dlfeval as fully host-side networks, so something like
dlnet1 = dlupdate(@gather,dlnet1);
dlnet2 = dlupdate(@gather,dlnet2);
[gradientsSubnet,gradientsParams,loss] = dlfeval(@modelGradients,dlnet1,dlnet2,fcParams,dlX1,dlX2,pairLabels);
If you don't do this, then it won't make any difference what you do inside modelGradients, because MATLAB will hold onto the GPU copy from the calling code.
You should also remove the fcParams part of the code: since you seem to have deleted the fullyconnect operation, those parameters are just wasting space.
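Putting the pieces above together, a complete modelGradients function might look roughly like this. This is a minimal sketch, not code from this thread: the absolute-difference similarity score and the binarycrossentropy helper are assumptions borrowed from the standard Siamese training example, and you would adapt them to your own loss head.

```matlab
function [gradients1,gradients2,loss] = modelGradients(dlnet1,dlnet2,dlX1,dlX2,pairLabels)
    % Move the first subnetwork onto the GPU only while it is in use,
    % then gather its weights back to host memory
    dlnet1 = dlupdate(@gpuArray,dlnet1);
    F1 = sigmoid(forward(dlnet1,dlX1));
    dlnet1 = dlupdate(@gather,dlnet1);

    % Do the same for the second subnetwork, so only one set of weights
    % occupies GPU memory at a time
    dlnet2 = dlupdate(@gpuArray,dlnet2);
    F2 = sigmoid(forward(dlnet2,dlX2));
    dlnet2 = dlupdate(@gather,dlnet2);

    % Example similarity score and loss (assumed; binarycrossentropy is a
    % helper defined in the official Siamese network example)
    Y = sigmoid(sum(abs(F1 - F2),1));
    loss = binarycrossentropy(Y,pairLabels);

    % Gradients with respect to the learnables of each subnetwork
    [gradients1,gradients2] = dlgradient(loss,dlnet1.Learnables,dlnet2.Learnables);
end
```

Since both networks are gathered back before the function returns, the caller's copies stay host-side between iterations, matching the dlfeval call shown above.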

