trainNetwork to trainnet conversion
Hi there,
I was using the trainNetwork (https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_408bdd15-2d34-4c0d-ad91-bc83942f7493) function for my study. However, in R2024b the trainnet (https://www.mathworks.com/help/deeplearning/ref/trainnet.html#mw_ffa5eeae-b6e0-444e-a464-91e257cef95b) function computes slightly faster. I tried to convert my trainNetwork call to trainnet, but I couldn't manage it. How can I convert it? My code is written below. Thank you.
%% Train network part
numClasses = numel(categories(trainImgs.Labels));
dropoutProb = 0.2;
layers = [ ... ]; % my network layers go here
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress', ...
'MiniBatchSize',64, ...
'ValidationData',valImgs, ...
'ExecutionEnvironment','gpu');
%% Training network
trainednet = trainNetwork(trainImgs,layers,options);
% trainednet = trainnet(trainImgs,layers,"crossentropy",options)
Answers (1)
Paras Gupta
on 15 Jul 2024
Hi Emre,
I understand that you are experiencing issues when transitioning from the 'trainNetwork' function to the 'trainnet' function in MATLAB.
From the code provided in the question, I assume that you are using the same layer array for both functions. However, the layer array passed to 'trainnet' must not include an output layer (here, the 'classificationLayer'). Instead of an output layer, the loss is specified through the 'lossFcn' argument.
The following example code illustrates the difference in network architecture between the two functions:
%% Load and Preprocess Data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
% Split data into training and validation sets
[trainImgs, valImgs] = splitEachLabel(imds, 0.8, 'randomized');
%% Define Network Architectures
% Network for trainNetwork (includes output layer)
layersTrainNetwork = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% Network for trainnet (no classificationLayer at the end)
% The loss is specified in the trainnet call instead. Keep the softmaxLayer,
% because the "crossentropy" loss expects the network to output probabilities.
layersTrainnet = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
dropoutLayer(0.2)
fullyConnectedLayer(10)
softmaxLayer];
%% Training Options
options = trainingOptions('adam', ...
'Plots','training-progress', ...
'MiniBatchSize',64, ...
'ValidationData',valImgs, ...
'ExecutionEnvironment','gpu');
%% Train Network using trainNetwork
trainednet = trainNetwork(trainImgs, layersTrainNetwork, options);
%% Train Network using trainnet
trainednet_ = trainnet(trainImgs, layersTrainnet, "crossentropy", options);
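A related point when converting: 'trainnet' returns a 'dlnetwork' object, so you can no longer call 'classify' on the trained network as you would with the output of 'trainNetwork'. Below is a minimal sketch of inference with the converted network, assuming R2024a or later (where 'minibatchpredict' and 'scores2label' are available); the variable 'classNames' is introduced here for illustration.
%% Classify validation images with the dlnetwork returned by trainnet
classNames = categories(trainImgs.Labels);        % class names taken from the training labels
scores = minibatchpredict(trainednet_, valImgs);  % softmax scores, one row per image
predLabels = scores2label(scores, classNames);    % convert scores to categorical labels
accuracy = mean(predLabels == valImgs.Labels)     % fraction of correctly classified images
Here 'minibatchpredict' handles the datastore batching for you, so no manual loop over the validation images is needed.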
Please refer to the following documentation links for more information on the differences between 'trainNetwork' and 'trainnet' functions:
- 'trainNetwork' version history - https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html#mw_408bdd15-2d34-4c0d-ad91-bc83942f7493
- 'trainingOptions' version history - https://www.mathworks.com/help/deeplearning/ref/trainingoptions.html#mw_ef0175ee-4c8f-443d-87c8-9b531abf3133
Hope this helps resolve the issue.