LSTM network stuck in local optima because of gradient tracking?

Joonas Vierijärvi on 6 Oct 2023
Hi,
I'm quite new to LSTM networks and ML in general. However, I decided to start by trying to model a guitar amplifier with an LSTM network as described here: https://www.mdpi.com/2076-3417/10/3/766
I decided to create a custom training function for this so I could be "sure" that I'm doing the model updates in a similar manner. Now I cannot get my loss to go below 0.5, even though I am using the same dataset as they did (in their case the loss was ~0.05). Is there something clearly wrong in my update functions or in the data format I'm feeding in? Basically, each iteration I put a 1 x N x 2048 array into the network, where N is the size of the mini-batch.
I did some reading on backpropagation through time and started wondering whether that is affecting my model. As I understand it, in PyTorch you can use hidden.detach to exclude the hidden state from backpropagation, and you can also zero the gradients of the tensors being optimized; this would then happen between each mini-batch. Can this be done in MATLAB, or does MATLAB do it automatically when using dlgradient?
Update model:
% Forward samples 1-1000 to initialize the hidden state.
[~,state] = forward(net,X2(1,:,1:1000));
net.State = state;

% Process the remaining samples in sequences of length 2048.
for a = 1:10
    iteration = iteration + 1;

    % Indices of the current sequence.
    idb = (a-1)*sequencelength+1001:a*sequencelength+1000;

    % Evaluate the model loss, gradients, and state using dlfeval and the
    % modelLoss function.
    [loss,gradients,state] = dlfeval(@modelLoss,net,X2(1,:,idb),T2(1,:,idb),weightDecay);
    net.State = state;

    % Update the network parameters using the Adam optimizer.
    [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad,averageSqGrad,iteration,learnRate,0.9,0.999,1e-8);

    loss2 = loss2 + loss;
end
net = resetState(net);
Cost:
function [loss,gradients,state] = modelLoss(net,X,T,weightDecay)
% Forward data through the network.
[Y,state] = forward(net,X);

% Error-to-signal ratio (ESR) and DC error terms, computed over the
% time dimension (dim 3) for each sequence in the mini-batch.
esr = sum(abs(T-Y).^2,3)./sum(abs(T).^2,3);
edc = abs((1/size(T,3))*sum(T-Y,3)).^2./((1/size(T,3))*sum(abs(T).^2,3));

% Average the per-sequence losses over the mini-batch (dim 2).
loss = sum(esr+edc)/size(Y,2);

% L2 regularization over all learnable parameters.
allParams = net.Learnables.Value;
L = dlupdate(@(x) sum(x.^2,"all"),allParams);
L = sum(cat(1,L{:}));
loss = loss + weightDecay*0.5*L;

% Calculate gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,net.Learnables);
end

Answers (1)

Krishna on 24 Nov 2023
Hello Joonas,
It's unlikely that an LSTM network gets stuck in a local optimum solely due to gradient tracking (the backpropagation mechanism you mentioned). Local optima are a consequence of the non-convex nature of the loss function rather than of the specific way gradients are tracked. Backpropagation iteratively updates the weights to minimize the loss, but vanishing or exploding gradients in recurrent networks like LSTMs can cause training issues. Techniques such as gradient clipping, careful weight initialization, or LSTM variants (e.g., GRU) can help mitigate these problems; a gradient clipping sketch is shown below. If you are facing convergence issues, it is worth exploring these techniques and adjusting the learning rate or optimizer.
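For example, here is a minimal sketch of global L2-norm gradient clipping that you could place between the dlfeval and adamupdate calls in your loop. The threshold value gradThreshold is an assumed hyperparameter, not something from your code:
gradThreshold = 1; % assumed value; tune for your model
% Compute the global L2 norm over all parameter gradients.
gradNorm = 0;
for i = 1:numel(gradients.Value)
    gradNorm = gradNorm + sum(gradients.Value{i}.^2,"all");
end
gradNorm = sqrt(extractdata(gradNorm));
% Rescale all gradients if the global norm exceeds the threshold.
if gradNorm > gradThreshold
    gradients = dlupdate(@(g) g*(gradThreshold/gradNorm), gradients);
end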
Secondly, there is no need to write a custom training loop; instead you can use "trainnet", where you can define your own loss function, which avoids this kind of mistake. Please go through the trainnet documentation to learn more.
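As a rough illustration only (a sketch, not tested on your data), the call could look like the following. It assumes your training sequences are held in hypothetical arrays XTrain and TTrain in a format trainnet accepts, that net is your untrained dlnetwork, and that time runs along the third dimension; also note that trainnet was introduced in R2023b:
options = trainingOptions("adam", ...
    MaxEpochs=100, ...            % assumed values; adjust as needed
    InitialLearnRate=5e-4, ...
    GradientThreshold=1, ...      % built-in gradient clipping
    Plots="training-progress");

% ESR-style loss as a function handle, assuming time is dimension 3.
lossFcn = @(Y,T) mean(sum(abs(T-Y).^2,3)./sum(abs(T).^2,3),"all");

netTrained = trainnet(XTrain,TTrain,net,lossFcn,options);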
Also, please make sure you have implemented the loss function correctly. Please go through this example to learn more about training networks through custom loops.
In MATLAB, excluding certain tensors from backpropagation and zeroing gradients can be achieved through manual control, similar to the functionality provided by "hidden.detach" in PyTorch. To exclude a tensor from backpropagation, simply leave it out of the gradient computation. You can also set gradients to zero after computing them with "dlgradient".
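Note that, as far as I understand it, dlarray tracing only happens inside the function evaluated by dlfeval, so a state that you assign to net.State between mini-batches is effectively already detached, and dlgradient does not accumulate gradients across calls the way PyTorch optimizers do, so there is nothing to zero manually. If you still want to strip any derivative trace explicitly, a minimal sketch (assuming net is a dlnetwork) would be:
state = net.State;
% Replace each state value with an untraced dlarray, roughly equivalent
% to hidden.detach() in PyTorch.
state.Value = cellfun(@(v) dlarray(extractdata(v)), state.Value, 'UniformOutput', false);
net.State = state;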
Hope this helps.
