How can we use the Nadam optimizer in place of sgdm when training deep learning networks?

9 views (last 30 days)
kollikonda Ashok kumar on 29 Mar 2023
Commented: Amanjit Dulai on 25 Oct 2024 at 11:06
Training_Options = trainingOptions('sgdm', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 50, ...
    'InitialLearnRate', 1e-5, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', Resized_Validation_Data, ...
    'ValidationFrequency', 40, ...
    'ExecutionEnvironment', 'gpu', ...
    'Plots', 'training-progress', ...
    'Verbose', false);

Answers (2)

Joss Knight on 4 Apr 2023
You cannot do this using trainNetwork. You need to use a dlnetwork with a custom training loop so you can author your own update rule. Perhaps adam will work for you instead.
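For reference, a minimal sketch of the options from the question switched to Adam (the validation datastore Resized_Validation_Data comes from the question and is assumed to exist in the workspace; the two decay factors shown are the trainingOptions defaults):
% Sketch only: the asker's trainingOptions with the solver changed to Adam.
Training_Options = trainingOptions('adam', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 50, ...
    'InitialLearnRate', 1e-5, ...
    'GradientDecayFactor', 0.9, ...          % beta1 (default)
    'SquaredGradientDecayFactor', 0.999, ... % beta2 (default)
    'Shuffle', 'every-epoch', ...
    'ValidationData', Resized_Validation_Data, ...
    'ValidationFrequency', 40, ...
    'ExecutionEnvironment', 'gpu', ...
    'Plots', 'training-progress', ...
    'Verbose', false);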

Amanjit Dulai on 25 Oct 2024 at 11:02
You can train with Nadam by defining a custom training loop. The function dlupdate can be used to apply custom update rules to the learnable parameters. The Nadam update rules (written here to match the code below, with $\beta_1$ = gradientDecay, $\beta_2$ = squaredGradientDecay, $\psi$ = momentumDecay, $\alpha$ = learnRate, and $g_t$ the gradient at iteration $t$) are:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t$$
$$n_t = \beta_2 n_{t-1} + (1-\beta_2)\,g_t^2$$
$$\hat{m}_t = \frac{\mu_{t+1}\,m_t}{1-\prod_{i=1}^{t+1}\mu_i} + \frac{(1-\mu_t)\,g_t}{1-\prod_{i=1}^{t}\mu_i}$$
$$\hat{n}_t = \frac{n_t}{1-\beta_2^{\,t}}$$
$$\theta_t = \theta_{t-1} - \frac{\alpha\,\hat{m}_t}{\sqrt{\hat{n}_t}+\epsilon}$$

where the momentum schedule is given by:

$$\mu_t = \beta_1\left(1 - 0.5\cdot 0.96^{\,t\psi}\right)$$
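As a minimal illustration of the dlupdate mechanism, here is a sketch of a plain SGD step (not Nadam); net is assumed to be a dlnetwork and gradients the table returned by dlgradient, as in the full example that follows:
% Sketch only: apply w <- w - 0.01*g to every learnable parameter. The
% function handle receives each parameter and its matching gradient.
net.Learnables = dlupdate(@(w,g) w - 0.01*g, net.Learnables, gradients);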
Below is an example of how to train a digit classification network using Nadam in a custom training loop:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay); % mu_1 of the momentum schedule above
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net, X);
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end

function [X, T] = preprocessMiniBatch(XCell, TCell)
    X = cat(4, XCell{1:end});
    T = cat(2, TCell{1:end});
    T = onehotencode(T, 1);
end
1 comment
Amanjit Dulai on 25 Oct 2024 at 11:06
Also, if you want to apply L2 regularization (weight decay) to the weights only, and not to the biases, you can modify the example as shown below:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias"); % regularize weights only, not biases
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply L2 weight regularization to the gradients
        gradients(l2Indices,:) = dlupdate(@(g,w)g + l2RegularizationFactor*w, ...
            gradients(l2Indices,:), net.Learnables(l2Indices,:));
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net, X);
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end

function [X, T] = preprocessMiniBatch(XCell, TCell)
    X = cat(4, XCell{1:end});
    T = cat(2, TCell{1:end});
    T = onehotencode(T, 1);
end
One thing to note: with adaptive update rules like Adam and Nadam, it has been found that it is often more effective to apply the weight decay directly to the weights rather than adding it to the gradients (decoupled weight decay, as in AdamW). Applying this idea to Nadam gives the NadamW algorithm.
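Concretely, the only change from the previous example is where the regularization term enters the update. With $\lambda$ = l2RegularizationFactor and $\alpha$ = learnRate (matching the Nadam equations above), the previous example folds L2 regularization into the gradient, $g_t \leftarrow g_t + \lambda\,\theta_{t-1}$, before the Nadam step, whereas NadamW shrinks the weights directly, $\theta_{t-1} \leftarrow \theta_{t-1} - \alpha\lambda\,\theta_{t-1}$, and then takes the Nadam step on the unmodified gradient. Below is an example of how to use NadamW in a custom training loop: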
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias"); % decay weights only, not biases
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply decoupled weight decay (NadamW)
        net.Learnables(l2Indices,:) = dlupdate(@(w)w - learnRate*l2RegularizationFactor*w, ...
            net.Learnables(l2Indices,:));
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat);
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net, X);
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end

function [X, T] = preprocessMiniBatch(XCell, TCell)
    X = cat(4, XCell{1:end});
    T = cat(2, TCell{1:end});
    T = onehotencode(T, 1);
end
