Multiple inputs to a pre-trained model

Rayan Matlob on 5 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
I have three class folders (Good, Moderate and Severe).
Each class folder contains six subfolders: Original images, Red, Blue, Green, HUE and Value.
The Red, Blue, Green, HUE and Value subfolders contain the images from the Original images folder after filters have been applied to them.
I am using a pre-trained model (resnet50, or any other model you would suggest). The images in all subfolders are numbered in the same sequence (each subfolder contains images 1 to 200).
How can I train the model so that each image from the Original images subfolder is fed to the model input in parallel with the corresponding images from the other subfolders (Red, Blue, Green, HUE and Value)?
Note: for validation, I need to use only the Original images folder, and the model should fetch the other images from the other subfolders.
Here is my MATLAB code, thanks:
1 comment
Rayan Matlob on 6 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
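% Index every image under the root folder. With 'LabelSource','foldernames',
% each label is the name of the folder that directly contains the file
% (e.g. 'Red' or 'Original images'), not the class folder (Good/Moderate/Severe).
imds = imageDatastore('C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');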
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.77,'randomized');
numTrainImages = numel(imdsTrain.Labels);
net = resnet50;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
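% findLayersToReplace is a supporting function from the MathWorks transfer-learning
% example and must be on the MATLAB path; this edit call only opens it in the Editor.
edit(fullfile(matlabroot,'examples','nnet','main','findLayersToReplace.m'))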
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));
if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
layers = lgraph.Layers;
connections = lgraph.Connections;
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',60, ...
    'InitialLearnRate',0.00065, ...
    'Shuffle','every-epoch', ...
    'ValidationFrequency',valFrequency, ...
    'ValidationData',augimdsValidation, ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options);
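One possible approach, sketched here under stated assumptions rather than as a verified solution, is early fusion by channel stacking: index only the Original images files, fetch the same-named file from each filter subfolder at read time with transform(..., 'IncludeInfo', true), and concatenate the six images along the third dimension. Validation is then driven by the Original images folder alone. Because resnet50 expects 3 input channels, the sketch swaps its input layer and first convolution for 18-channel versions. It assumes that every class folder holds the six subfolders listed in subNames, that matching images share the same file name, and that all images are uint8 grayscale or RGB; classFromPath, stackSiblings and toRGBSingle are hypothetical helper names. A multi-branch network with one input per image is another option; channel stacking keeps a single-input ResNet-50.

```matlab
% Sketch: early fusion by channel stacking (assumptions in the text above).
rootDir  = 'C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3';
classes  = {'Good','Moderate','Severe'};
subNames = {'Original images','Red','Blue','Green','HUE','Value'};  % assumed names

% Index ONLY the original images and label each one by its class folder.
imdsOrig = imageDatastore(fullfile(rootDir, classes, subNames{1}));
imdsOrig.Labels = categorical(cellfun(@classFromPath, imdsOrig.Files, ...
    'UniformOutput', false));
[imdsTrain, imdsValidation] = splitEachLabel(imdsOrig, 0.77, 'randomized');

net = resnet50;
inputSize = net.Layers(1).InputSize;          % [224 224 3]
numImages = numel(subNames);
numChannels = 3 * numImages;                  % 18 stacked channels

% Each read returns {H-by-W-by-18 image, label}, a format trainNetwork accepts.
readFcn = @(img, info) stackSiblings(img, info, subNames, inputSize(1:2));
dsTrain = transform(imdsTrain, readFcn, 'IncludeInfo', true);
dsValidation = transform(imdsValidation, readFcn, 'IncludeInfo', true);

% Replace the 3-channel input layer and first convolution. The pretrained
% conv1 filters are repeated over the six images and divided by six.
lgraph = layerGraph(net);
oldIn = net.Layers(1);
oldConv = net.Layers(2);
assert(isa(oldConv, 'nnet.cnn.layer.Convolution2DLayer'))
newIn = imageInputLayer([inputSize(1:2) numChannels], 'Name', oldIn.Name);
newConv = convolution2dLayer(oldConv.FilterSize, oldConv.NumFilters, ...
    'Name', oldConv.Name, 'Stride', oldConv.Stride, ...
    'Padding', oldConv.PaddingSize, ...
    'Weights', repmat(oldConv.Weights, 1, 1, numImages) / numImages, ...
    'Bias', oldConv.Bias);
lgraph = replaceLayer(lgraph, oldIn.Name, newIn);
lgraph = replaceLayer(lgraph, oldConv.Name, newConv);

% Replace the 1000-class head (last FC layer and classification layer).
oldFc = net.Layers(end-2);
assert(isa(oldFc, 'nnet.cnn.layer.FullyConnectedLayer'))
lgraph = replaceLayer(lgraph, oldFc.Name, fullyConnectedLayer(numel(classes), ...
    'Name', 'new_fc', 'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10));
lgraph = replaceLayer(lgraph, net.Layers(end).Name, ...
    classificationLayer('Name', 'new_classoutput'));

miniBatchSize = 10;
options = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 60, ...
    'InitialLearnRate', 0.00065, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', dsValidation, ...
    'ValidationFrequency', floor(numel(imdsTrain.Files) / miniBatchSize), ...
    'Verbose', false, ...
    'Plots', 'training-progress');
trainedNet = trainNetwork(dsTrain, lgraph, options);

function c = classFromPath(f)
    % <root>\<class>\Original images\<n>.<ext>  ->  '<class>'
    [~, c] = fileparts(fileparts(fileparts(f)));
end

function [dataOut, infoOut] = stackSiblings(img, info, subNames, outSize)
    % Hypothetical helper: stack the original image with the same-named
    % images from the sibling filter subfolders along dimension 3.
    classDir = fileparts(fileparts(info.Filename));   % <root>\<class>
    [~, name, ext] = fileparts(info.Filename);
    slices = cell(1, numel(subNames));
    slices{1} = toRGBSingle(img, outSize);            % subNames{1} is the original
    for k = 2:numel(subNames)
        sibling = imread(fullfile(classDir, subNames{k}, [name ext]));
        slices{k} = toRGBSingle(sibling, outSize);
    end
    dataOut = {cat(3, slices{:}), info.Label};
    infoOut = info;
end

function I = toRGBSingle(I, outSize)
    % Resize, replicate grayscale to 3 channels, keep the 0-255 value range.
    I = imresize(I, outSize);
    if size(I, 3) == 1
        I = repmat(I, 1, 1, 3);
    end
    I = single(I);
end
```

Initializing the new first convolution from the repeated pretrained filters keeps the ImageNet features of that layer (six identical input images would give the original response); a freshly initialized convolution2dLayer is the simpler alternative. The new input layer keeps the default 'zerocenter' normalization, so its mean is computed from the training data when training starts.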

Answers (0)

Categories

Deep Learning Toolbox

Version

R2021b