Animated-Gradient-Descent
Gradient descent is arguably one of the most important algorithms in machine learning and deep learning. It is used to train everything from simple machine learning models to complex deep neural networks.
This MATLAB script generates an animated GIF that visualizes how gradient descent works on a 3D surface plot or a contour plot, which can be very helpful for machine learning beginners.
In addition, users can set the cost function, the learning rate (alpha), and the starting point of gradient descent. This lets advanced users demonstrate or investigate how these factors affect the training of a machine learning model.
Important features:
- Supports both 3D surface plots and contour plots
- Lets the user enter the (cost) function to plot
- Lets the user adjust the learning rate to understand its effect
- Lets the user set the starting point of gradient descent, for example to investigate the saddle point effect
- Lets the user set the maximum number of steps and the stop threshold
- Lets the user set an array of learning rates to compare
- Uses either a filled 2D contour plot (contourf) or a line 2D contour plot (contour)
- Generates an animated GIF that can be inserted into web pages and presentations
See test.m for getting-started examples. First, create an animateGraDes object:
agd = animateGraDes();
Next, set the (cost) function for gradient descent. This is required:
agd.funcStr='x^2+2*x*y+3*y^2+4*x+5*y+6';
This is the function used to generate the animation GIF of the contour plot above.
If you do not need to set any optional parameters, just call:
agd.animate();
This starts gradient descent. Once the local minimum is found, an annotation text box appears in the figure, showing the alpha value used and the local minimum found.
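At its core, each gradient descent step moves the current point against the gradient: p_next = p - alpha * grad f(p). As a rough illustration, here is a minimal Python sketch (not MATLAB, and not the animateGraDes implementation) that takes a MATLAB-style function string like funcStr, estimates the gradient with central differences, and iterates until the step length drops below a threshold. The helper names (make_func, numeric_grad, gradient_descent), the finite-difference gradient, the starting point (2, 3), and the step-length stopping rule are all assumptions made for this sketch; animateGraDes may compute gradients and decide when to stop differently.

```python
import math

def make_func(func_str):
    """Hypothetical helper: turn a MATLAB-style string such as 'x^2+2*x*y'
    into a Python callable f(x, y). Only '^' is translated (to '**');
    eval() is acceptable for a local demo but unsafe for untrusted input."""
    code = compile(func_str.replace("^", "**"), "<funcStr>", "eval")
    return lambda x, y: eval(code, {"__builtins__": {}}, {"x": x, "y": y})

def numeric_grad(f, x, y, h=1e-5):
    """Central-difference estimate of (df/dx, df/dy)."""
    return ((f(x + h, y) - f(x - h, y)) / (2 * h),
            (f(x, y + h) - f(x, y - h)) / (2 * h))

def gradient_descent(f, start, alpha=0.1, stop_threshold=1e-10, max_steps=10_000):
    """Iterate p <- p - alpha * grad f(p); stop when a step is shorter than stop_threshold."""
    x, y = start
    for step in range(1, max_steps + 1):
        gx, gy = numeric_grad(f, x, y)
        dx, dy = -alpha * gx, -alpha * gy
        x, y = x + dx, y + dy
        if math.hypot(dx, dy) < stop_threshold:  # assumed stopping rule
            return (x, y), step
    return (x, y), max_steps

f = make_func("x^2+2*x*y+3*y^2+4*x+5*y+6")
(x_min, y_min), steps = gradient_descent(f, start=(2.0, 3.0))
print(f"local min ~ ({x_min:.6f}, {y_min:.6f}) after {steps} steps")
```

For this quadratic the exact minimum is at (-1.75, -0.25), found by setting both partial derivatives, 2x + 2y + 4 and 2x + 6y + 5, to zero, so the printed point should be very close to it.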
The following parameters are all optional. If they are not set, their default values are used. Note that they must be set before calling agd.animate() to take effect.
xrange and yrange set the range for plotting the function. IMPORTANT: make sure that both the local minimum and the starting point of gradient descent lie within this range; otherwise they will not be shown. You set the x (or y) range by giving the min, the increment, and the max. For example, the values below mean min = -10, increment = 1, and max = 10:
agd.xrange = -10:1:10;
agd.yrange = -10:1:10;
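The MATLAB colon expression -10:1:10 expands to 21 evenly spaced values from -10 to 10. The hypothetical Python helpers below (colon_range and in_plot_range are illustrative names, not part of animateGraDes) mimic that expansion and check the requirement above that the starting point and the minimum both fall inside the plotted rectangle. An example starting point (2, 3) and the minimum (-1.75, -0.25) of the example function serve as test points.

```python
def colon_range(lo, step, hi):
    """Rough Python equivalent of MATLAB's lo:step:hi for step > 0."""
    count = int((hi - lo) / step + 1e-9) + 1   # small tolerance for float rounding
    return [lo + i * step for i in range(count)]

def in_plot_range(xs, ys, point):
    """True if an (x, y) point lies inside the plotted rectangle."""
    x, y = point
    return xs[0] <= x <= xs[-1] and ys[0] <= y <= ys[-1]

xs = colon_range(-10, 1, 10)
ys = colon_range(-10, 1, 10)
print(len(xs), xs[0], xs[-1])                 # 21 -10 10
print(in_plot_range(xs, ys, (2, 3)))          # True: example starting point
print(in_plot_range(xs, ys, (-1.75, -0.25)))  # True: exact minimum of the example function
```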
The learning rate alpha is an important tuning factor when training a model. By using different values of alpha, animateGraDes can show how it affects training. For example, changing alpha from 0.1 to 0.2 produces this contour plot instead:
agd.alpha = 0.2;
The first contour plot uses alpha = 0.1, while this one uses alpha = 0.2 and shows the effect of overshooting.
Setting alpha greater than 0.3 for this cost function makes gradient descent diverge: it moves away from the local minimum and never reaches it.
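The 0.3 limit agrees with a standard result for quadratic functions: with a fixed learning rate, gradient descent converges only if alpha < 2 / lambda_max, where lambda_max is the largest eigenvalue of the (constant) Hessian. The Python sketch below, which is independent of animateGraDes, computes this bound for the example function and the per-direction factor 1 - alpha * lambda by which each step scales the remaining error.

```python
import math

# f(x, y) = x^2 + 2xy + 3y^2 + 4x + 5y + 6 has the constant Hessian
# H = [[2, 2], [2, 6]].
a, b, c = 2.0, 2.0, 6.0

# Eigenvalues of the symmetric 2x2 matrix [[a, b], [b, c]]
mean = (a + c) / 2
radius = math.hypot((a - c) / 2, b)
lam_small, lam_large = mean - radius, mean + radius

alpha_limit = 2 / lam_large
print(f"eigenvalues: {lam_small:.4f}, {lam_large:.4f}")
print(f"gradient descent diverges for alpha > {alpha_limit:.4f}")

# Each step multiplies the error along an eigenvector by (1 - alpha * lambda).
for alpha in (0.1, 0.2, 0.3):
    factors = [1 - alpha * lam for lam in (lam_small, lam_large)]
    print(alpha, [round(r, 3) for r in factors])
```

The bound comes out at about 0.293. For alpha = 0.2 the factor along the steep direction is about -0.37: it is negative, so the iterate jumps to the other side of the valley on every step, which is the overshooting seen in the alpha = 0.2 plot. For alpha = 0.3 that factor is about -1.05, so the error along that direction grows by roughly 5% per step.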
If alpha is not set, the default learning rate of 0.1 is used.
You can also set an array of alphas to compare several learning rates:
agd.alpha = [0.05 0.1 0.2];
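To see what comparing several learning rates reveals, the Python sketch below (a hypothetical illustration, not animateGraDes code) counts how many plain gradient descent steps the example function needs for each alpha in [0.05 0.1 0.2]. It assumes a starting point of (2, 3) and stops when the step length falls below 1e-10; both are assumptions for the sketch.

```python
import math

def steps_to_converge(alpha, start=(2.0, 3.0), stop_threshold=1e-10, max_steps=100_000):
    """Count gradient descent steps on f(x, y) = x^2 + 2xy + 3y^2 + 4x + 5y + 6."""
    x, y = start
    for step in range(1, max_steps + 1):
        gx, gy = 2 * x + 2 * y + 4, 2 * x + 6 * y + 5   # analytic gradient
        dx, dy = -alpha * gx, -alpha * gy
        x, y = x + dx, y + dy
        if math.hypot(dx, dy) < stop_threshold:        # assumed stopping rule
            return step
    return max_steps

alphas = [0.05, 0.1, 0.2]        # same values as agd.alpha above
steps = {alpha: steps_to_converge(alpha) for alpha in alphas}
for alpha in alphas:
    print(f"alpha = {alpha}: {steps[alpha]} steps")
```

On this function, larger learning rates within the stable range need fewer steps, even though alpha = 0.2 overshoots along the steep direction.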
If the starting point is too far away from the local minimum, gradient descent can take a long time to reach it. In addition, if there is more than one local minimum, the starting point may determine which one gradient descent reaches.
The starting point is an [x y] array:
agd.startPoint=[2 3];
By using different functions and starting points, animateGraDes can also show the effect of saddle points.
The following plots were generated using the example from this paper:
agd = animateGraDes();
agd.alpha=0.15;
agd.funcStr='x^4-2*x^2+y^2'; % this function has a saddle point at [0 0]
agd.startPoint=[1.5 1.5];
agd.drawContour=false;
agd.xrange=-2:0.1:2;
agd.yrange=-2:0.1:2;
agd.animate();
A starting point of [1.5 1.5] avoids the saddle point, and gradient descent successfully reaches the local minimum [1 0].
Another starting point, [-1.5 -1.5], also avoids the saddle point, but gradient descent ends up in the other local minimum, [-1 0]:
agd.startPoint=[-1.5 -1.5];
Now change the starting point to [0 1.5]. This time the starting point lies exactly on the ridge that separates the two local minima:
agd.startPoint=[0 1.5];
Now gradient descent stops at the saddle point [0 0]:
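The three runs above can be reproduced with a short Python sketch (an illustration under assumptions, not animateGraDes itself). The gradient of f(x, y) = x^4 - 2x^2 + y^2 is (4x^3 - 4x, 2y), so its critical points are the minima (1, 0) and (-1, 0) and the saddle point (0, 0). Starting at x = 0, the x-component of the gradient is exactly zero, so x never changes while y decays, which is why gradient descent stalls at [0 0]. alpha = 0.15 matches the example above; the step-length stopping rule is an assumption.

```python
import math

def grad(x, y):
    """Gradient of f(x, y) = x^4 - 2x^2 + y^2."""
    return 4 * x**3 - 4 * x, 2 * y

def descend(start, alpha=0.15, stop_threshold=1e-10, max_steps=10_000):
    x, y = start
    for _ in range(max_steps):
        gx, gy = grad(x, y)
        dx, dy = -alpha * gx, -alpha * gy
        x, y = x + dx, y + dy
        if math.hypot(dx, dy) < stop_threshold:   # assumed stopping rule
            break
    return x, y

starts = [(1.5, 1.5), (-1.5, -1.5), (0.0, 1.5)]
ends = {start: descend(start) for start in starts}
for start, (x, y) in ends.items():
    print(f"start {start} -> ({x:.6f}, {y:.6f})")
```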
From a physics point of view, a local minimum is a stable equilibrium, while a saddle point is an unstable equilibrium. A small perturbation can therefore break the balance at an unstable point, and this can be used to escape a saddle point.
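A minimal Python sketch of this idea, which is also the basis of perturbed variants of gradient descent, assuming the same function and alpha = 0.15 as above plus a hypothetical one-time kick of size 1e-6 with a random sign, applied after gradient descent stalls. This is not an animateGraDes feature; it only illustrates why a saddle point is unstable.

```python
import math
import random

def grad(x, y):
    """Gradient of f(x, y) = x^4 - 2x^2 + y^2."""
    return 4 * x**3 - 4 * x, 2 * y

def descend(start, alpha=0.15, stop_threshold=1e-10, max_steps=10_000):
    x, y = start
    for _ in range(max_steps):
        gx, gy = grad(x, y)
        dx, dy = -alpha * gx, -alpha * gy
        x, y = x + dx, y + dy
        if math.hypot(dx, dy) < stop_threshold:   # assumed stopping rule
            break
    return x, y

# 1. Starting on the ridge x = 0, gradient descent stalls at the saddle point.
stuck = descend((0.0, 1.5))

# 2. Kick the stalled point by 1e-6 in a random direction and resume.
rng = random.Random(0)                                  # fixed seed for reproducibility
kick = [1e-6 * rng.choice((-1.0, 1.0)) for _ in range(2)]
escaped = descend((stuck[0] + kick[0], stuck[1] + kick[1]))

print("stalled at", stuck)
print("after perturbation", escaped)
```

Near x = 0 each step multiplies x by roughly 1 + 4 * alpha = 1.6, so even a 1e-6 kick grows until the iterate settles in one of the minima [1 0] or [-1 0]; which one depends on the sign of the kick.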
As gradient descent approaches the local minimum, the gradient becomes smaller and smaller, so each step gets shorter. A stopThreshold can be set to stop gradient descent.
The default is 1E-10. For better accuracy, set it to a smaller value:
agd.stopThreshold = 1E-16;
Keep in mind that a smaller stopThreshold can give better accuracy, but gradient descent will take (much) longer to reach the answer.
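The accuracy/time trade-off can be seen in a small Python sketch on the first example function. It assumes the threshold is compared against the length of each step (an assumption; animateGraDes may test a different quantity), assumes a starting point of (2, 3), and uses 1e-14 rather than 1E-16 as the tighter value to stay clear of double-precision rounding in this illustration.

```python
import math

def descend_with_threshold(stop_threshold, alpha=0.1, start=(2.0, 3.0), max_steps=100_000):
    """Gradient descent on f(x, y) = x^2 + 2xy + 3y^2 + 4x + 5y + 6.
    Returns (steps taken, distance to the exact minimum (-1.75, -0.25))."""
    x, y = start
    for step in range(1, max_steps + 1):
        gx, gy = 2 * x + 2 * y + 4, 2 * x + 6 * y + 5
        dx, dy = -alpha * gx, -alpha * gy
        x, y = x + dx, y + dy
        if math.hypot(dx, dy) < stop_threshold:   # assumed stopping rule
            break
    return step, math.hypot(x + 1.75, y + 0.25)

loose_steps, loose_error = descend_with_threshold(1e-10)
tight_steps, tight_error = descend_with_threshold(1e-14)
print(f"1e-10: {loose_steps} steps, error {loose_error:.1e}")
print(f"1e-14: {tight_steps} steps, error {tight_error:.1e}")
```

The tighter threshold should report both more steps and a smaller distance to the exact minimum.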
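fillContour is an optional flag that defaults to false. Set it to true to draw a filled contour plot (contourf) instead of a line contour plot (contour):
agd.fillContour = true;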
The optional outfile parameter saves the animation as a GIF file. Set it like this:
agd.outfile='animateGraDes.gif';
By default it is empty, and no animated GIF is generated. You can still save the resulting figure as a PNG file.
Cite this source
Yongjian Feng (2024). Animated-Gradient-Descent (https://github.com/MATLAB-Graphics-and-App-Building/Animated-Gradient-Descent/releases/tag/v1.1), GitHub. Retrieved .
Platform compatibility: Windows, macOS, Linux
Version | Release notes
---|---
1.1 | See release notes for this release on GitHub: https://github.com/MATLAB-Graphics-and-App-Building/Animated-Gradient-Descent/releases/tag/v1.1
1.0.1 | See release notes for this release on GitHub: https://github.com/MATLAB-Graphics-and-App-Building/Animated-Gradient-Descent/releases/tag/v1.0.1
1.0 |