Custom training loop and parameters for a semantic segmentation problem

Manon on 5 Jul 2024 at 14:03
Answered: praguna manvi on 15 Jul 2024 at 13:09
Good day
I have been working on a semantic segmentation problem with the Deep Learning Toolbox, using the unetLayers function.
However, I now need to adjust some parameters that I believe are not accessible through the toolbox's built-in training functions. I would like to:
  • Use a custom loss function: the sum of a weighted focal loss and a Dice loss.
  • Augment the data at every epoch, ideally with a custom augmentation method called RLR (Random Local Rotation).
  • Use a learning-rate schedule based on the "one-cycle" policy.
  • Track the training progress in as much detail as possible.
Here is my base code, which works for training the U-Net:
% create_ds is the user's function returning the image and pixel-label datastores
[imds,pxds] = create_ds(folders);
numFiles = numel(imds.Files);
% First 10% of the files: held-out set, used as validation data
indicesTest = 1:floor(numFiles/10);
testImds = subset(imds,indicesTest);
testPxds = subset(pxds,indicesTest);
dsTest = combine(testImds,testPxds);
% Remaining files: training set (starts at floor(numFiles/10)+1 so the two sets do not overlap)
indicesTraining = floor(numFiles/10)+1:numFiles;
trainingImds = subset(imds,indicesTraining);
trainingPxds = subset(pxds,indicesTraining);
dsTraining = combine(trainingImds,trainingPxds);
% U-Net for 128x128 RGB images with two classes
imageSize = [128 128 3];
numClasses = 2;
unetNetwork = unetLayers(imageSize,numClasses,EncoderDepth = 3);
opts = trainingOptions("rmsprop", ...
InitialLearnRate = 1e-3, ...
MaxEpochs = 40, ...
MiniBatchSize = 32, ...
VerboseFrequency = 10, ...
Plots = "training-progress", ...
Shuffle = "every-epoch", ...
ValidationData = dsTest, ...
ValidationFrequency = 10, ...
OutputNetwork = "best-validation-loss");
currentNet = trainNetwork(dsTraining,unetNetwork,opts);
create_ds is a function that returns two datastores: imds contains the RGB images, and pxds contains the corresponding masks, stored as categorical images with two classes.
The RLR function returns [image,label]: the RGB and categorical images after a geometric transformation (their size and type are preserved).
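A minimal sketch of per-epoch augmentation with RLR, assuming RLR is called as [image,label] = RLR(image,label) (the post only describes its outputs, so adjust the call if its inputs differ). A transformed datastore runs its function at every read, so each epoch draws a new random rotation. The helper name applyRLR is hypothetical; the loop over rows follows the pattern used for combined image/label datastores, whose reads return an N-by-2 cell array.

```matlab
% Wrap the combined training datastore so that RLR runs at every read.
% Assumption: RLR has the signature [image,label] = RLR(image,label).
dsTrainingAug = transform(dsTraining, @applyRLR);

function dataOut = applyRLR(dataIn)
    % dataIn is an N-by-2 cell array: column 1 = RGB images, column 2 = categorical masks
    dataOut = dataIn;
    for k = 1:size(dataIn, 1)
        [dataOut{k,1}, dataOut{k,2}] = RLR(dataIn{k,1}, dataIn{k,2});
    end
end
```

dsTrainingAug can replace dsTraining either in trainNetwork or as the source of a minibatchqueue in a custom loop. In a script, applyRLR must sit at the end of the file, like any local function.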
Here is the function that computes the custom loss:
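function loss = combinedLoss(Y, T, alpha, gamma)
% Y: network outputs (logits); T: one-hot targets of the same size
% alpha: weight of the positive term; gamma: focusing parameter
epsilon = 1e-6;
p = sigmoid(Y); % per-pixel, per-class probabilities
% Weighted focal loss
lossPos = -alpha * (1 - p).^gamma .* T .* log(p + epsilon);
lossNeg = -(1 - alpha) * p.^gamma .* (1 - T) .* log(1 - p + epsilon);
weightedFocalLoss = mean(lossPos + lossNeg, 'all');
% Soft Dice loss, computed on the probabilities p (not on the raw outputs Y)
intersection = sum(p .* T, 'all');
union = sum(p, 'all') + sum(T, 'all');
diceCoeff = (2 * intersection + epsilon) / (union + epsilon);
diceLoss = 1 - diceCoeff;
loss = weightedFocalLoss + diceLoss;
end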
I have read the guides that explain how to write a custom training loop, ‘Train Network Using Custom Training Loop’ and ‘Monitor Custom Training Loop Progress’ (among others), several times, but I still do not understand how to adapt the examples to my semantic segmentation problem so that they handle my input data correctly.
I don't think sharing the code fragments I have tried to put together would be helpful, as they are far from working.
Any help, whether it answers my question partially or completely, would be greatly appreciated.
Have a nice weekend!

Answers (1)

praguna manvi on 15 Jul 2024 at 13:09
Hi, you can find many examples of training with a custom for-loop in this part of the documentation:
https://www.mathworks.com/help/deeplearning/examples.html?category=custom-training-loops. The GAN and style-transfer examples there are working examples that can help with this use case.
You can train a network with a custom learning-rate scheduler by computing the learning rate yourself at each iteration and passing it to the update function, “rmspropupdate” in this case.
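As an illustration of such a scheduler, here is one common way to write a one-cycle schedule. The shape (cosine warm-up over the first 30 % of the K iterations, then cosine annealing) and the constants are assumptions modeled on the defaults of PyTorch's OneCycleLR; the thread does not prescribe a particular variant, and momentum cycling is left out. At iteration k, with t = k/K and peak rate η_max:

$$
\eta(t) =
\begin{cases}
\eta_0 + \left(\eta_{\max} - \eta_0\right)\dfrac{1 - \cos\left(\pi t / p\right)}{2}, & 0 \le t \le p,\\[6pt]
\eta_{\min} + \left(\eta_{\max} - \eta_{\min}\right)\dfrac{1 + \cos\left(\pi \frac{t - p}{1 - p}\right)}{2}, & p < t \le 1,
\end{cases}
\qquad \eta_0 = \frac{\eta_{\max}}{25},\quad \eta_{\min} = \frac{\eta_0}{10^4},\quad p = 0.3.
$$

Both branches equal η_max at t = p, so the rate rises smoothly from η_0 to η_max and then decays to η_min; η(k/K) is the value to pass as the learning-rate argument of rmspropupdate at iteration k.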
The network can be trained on a custom loss by computing that loss inside a model-loss function and passing that function's handle to “dlfeval”, which lets “dlgradient” compute the gradients.
To record and visualize training metrics and the loss, consider using “trainingProgressMonitor”. You can also apply a custom augmentation method such as RLR to the images after they are read from the datastores, at every epoch or at specific iterations.
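Putting these suggestions together, here is a hedged sketch of a complete custom training loop for the question's setup (R2023b). It assumes the variables from the base code (imageSize, unetNetwork, dsTraining, dsTest, trainingImds), combinedLoss on the path, and uint8 RGB images. Because combinedLoss applies its own sigmoid, the network is cut so that it outputs logits. The helper names modelLoss, validationLoss, preprocessMiniBatch and oneCycleLR are hypothetical, and alpha, gamma, maxLR and the one-cycle constants are placeholders, not tuned values.

```matlab
% Sketch of a custom training loop (R2023b). Assumes imageSize, unetNetwork,
% dsTraining, dsTest and trainingImds from the base code, and combinedLoss on the path.
% Hyperparameters below are illustrative, not tuned values.
numEpochs = 40;
miniBatchSize = 32;
maxLR = 1e-3;                 % peak learning rate of the one-cycle schedule
alpha = 0.25; gamma = 2;      % focal-loss class weight and focusing parameter
validationFrequency = 10;     % validate every 10 iterations

% layerGraph -> dlnetwork that outputs logits: remove the softmax and pixel
% classification layers (combinedLoss applies its own sigmoid) and replace the
% input layer by one without normalization (images are scaled to [0,1] below).
layers = unetNetwork.Layers;
isHead = arrayfun(@(l) isa(l,"nnet.cnn.layer.SoftmaxLayer") || ...
    isa(l,"nnet.cnn.layer.PixelClassificationLayer"), layers);
lgraph = removeLayers(unetNetwork, {layers(isHead).Name});
inLayer = layers(arrayfun(@(l) isa(l,"nnet.cnn.layer.ImageInputLayer"), layers));
lgraph = replaceLayer(lgraph, inLayer.Name, ...
    imageInputLayer(imageSize, Normalization="none", Name=inLayer.Name));
net = dlnetwork(lgraph);

% Mini-batch queues returning formatted dlarrays ("SSCB" = spatial, spatial, channel, batch).
% To apply RLR at every epoch, build mbq from a transformed copy of dsTraining instead.
mbq = minibatchqueue(dsTraining, 2, MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, MiniBatchFormat=["SSCB","SSCB"]);
mbqVal = minibatchqueue(dsTest, 2, MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, MiniBatchFormat=["SSCB","SSCB"]);
numIterations = numEpochs * ceil(numel(trainingImds.Files) / miniBatchSize);

monitor = trainingProgressMonitor(Metrics=["TrainingLoss","ValidationLoss"], ...
    Info=["Epoch","LearnRate"], XLabel="Iteration");
groupSubPlot(monitor, "Loss", ["TrainingLoss","ValidationLoss"]);

avgSqGrad = [];
iteration = 0;
epoch = 0;
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    shuffle(mbq);   % reset and reshuffle the training data for the new epoch
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        [X, T] = next(mbq);
        % Loss and gradients are computed inside dlfeval so that dlgradient can trace them
        [loss, gradients, state] = dlfeval(@modelLoss, net, X, T, alpha, gamma);
        net.State = state;
        lr = oneCycleLR(iteration, numIterations, maxLR);
        [net, avgSqGrad] = rmspropupdate(net, gradients, avgSqGrad, lr);
        recordMetrics(monitor, iteration, TrainingLoss=double(gather(extractdata(loss))));
        if iteration == 1 || mod(iteration, validationFrequency) == 0
            recordMetrics(monitor, iteration, ...
                ValidationLoss=validationLoss(net, mbqVal, alpha, gamma));
        end
        updateInfo(monitor, Epoch=epoch, LearnRate=lr);
        monitor.Progress = 100 * iteration / numIterations;
    end
end

function [loss, gradients, state] = modelLoss(net, X, T, alpha, gamma)
    [Y, state] = forward(net, X);              % Y: logits, same size as T
    loss = combinedLoss(Y, T, alpha, gamma);
    gradients = dlgradient(loss, net.Learnables);
end

function L = validationLoss(net, mbqVal, alpha, gamma)
    % Mean of the per-mini-batch losses (so the Dice term is computed per batch)
    reset(mbqVal);
    total = 0;
    numBatches = 0;
    while hasdata(mbqVal)
        [X, T] = next(mbqVal);
        loss = combinedLoss(predict(net, X), T, alpha, gamma);
        total = total + double(gather(extractdata(loss)));
        numBatches = numBatches + 1;
    end
    L = total / numBatches;
end

function [X, T] = preprocessMiniBatch(imageCell, maskCell)
    X = single(cat(4, imageCell{:})) / 255;               % assumes uint8 RGB images
    T = onehotencode(cat(4, maskCell{:}), 3, "single");   % H-by-W-by-numClasses-by-B
end

function lr = oneCycleLR(iteration, numIterations, maxLR)
    % Assumed one-cycle variant: cosine warm-up from maxLR/25 to maxLR over the
    % first 30% of iterations, then cosine annealing down to (maxLR/25)/1e4.
    pctStart = 0.3;
    lrStart = maxLR / 25;
    lrEnd = lrStart / 1e4;
    t = iteration / numIterations;
    if t <= pctStart
        s = t / pctStart;
        lr = lrStart + (maxLR - lrStart) * (1 - cos(pi * s)) / 2;
    else
        s = (t - pctStart) / (1 - pctStart);
        lr = lrEnd + (maxLR - lrEnd) * (1 + cos(pi * s)) / 2;
    end
end
```

Since the helpers are local functions, the block must sit at the end of a script or in its own file. rmspropupdate matches the original "rmsprop" solver; other update functions such as adamupdate could be swapped in with their own state arguments. The validation loss here is an average of per-batch losses, so it will not exactly match a loss computed over the whole validation set at once.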
Hope this helps!

Categories

Find more on Image Data Workflows in Help Center and File Exchange

Version

R2023b
