Surajustement
Le surajustement (overfitting) est un comportement de Machine Learning qui se produit lorsque le modèle est si étroitement aligné avec les données d'apprentissage qu'il ne sait pas comment réagir à de nouvelles données. Le surajustement peut se produire pour les raisons suivantes :
- Le modèle de Machine Learning est trop complexe. Il mémorise des patterns très subtils dans les données d'apprentissage, qui ne se prêtent pas bien à la généralisation.
- Le volume des données d'apprentissage est trop faible pour la complexité du modèle et/ou les données contiennent de grandes quantités d'informations non pertinentes.
Vous pouvez éviter le surajustement en gérant la complexité du modèle et en améliorant le jeu de données d'apprentissage.
Surajustement et sous-ajustement
Le sous-ajustement (underfitting) est le concept inverse du surajustement : le modèle ne s'aligne pas bien avec les données d'apprentissage ou ne se généralise pas bien à de nouvelles données. Le surajustement et le sous-ajustement peuvent être présents à la fois dans les modèles de classification et de régression. La figure suivante illustre la manière dont la limite de décision de classification et la ligne de régression suivent les données d'apprentissage, de trop près pour un modèle surajusté, et pas assez près pour un modèle sous-ajusté.
Si l'on considère uniquement l'erreur calculée d'un modèle de Machine Learning pour les données d'apprentissage, le surajustement est plus difficile à détecter que le sous-ajustement. Ainsi, pour éviter le surajustement, il est important de valider un modèle de Machine Learning avant de l'utiliser sur des données de test.
Erreur |
Surajustement |
Ajustement correct |
Sous-ajustement |
Apprentissage |
Faible |
Faible |
Élevée |
Test |
Élevée |
Faible |
Élevée |
Utiliser MATLAB® avec Statistics and Machine Learning Toolbox™ et Deep Learning Toolbox™ peut vous aider à éviter le surajustement des modèles de Machine Learning et de Deep Learning. MATLAB propose des fonctions et des méthodes spécifiquement conçues pour éviter le surajustement des modèles. Vous pouvez utiliser ces outils lorsque vous entraînez ou ajustez votre modèle pour le protéger contre le surajustement.
Comment éviter le surajustement en réduisant la complexité du modèle
Avec MATLAB, vous pouvez entraîner des modèles de Machine Learning et de Deep Learning (tels que des CNN) en partant de zéro, ou bien tirer parti de modèles de Deep Learning pré-entraînés. Afin d'éviter le surajustement, effectuez une validation du modèle pour vous assurer que vous choisissez un modèle ayant le bon niveau de complexité pour vos données, ou utilisez la régularisation pour réduire la complexité du modèle.
Validation du modèle
L'erreur d'un modèle surajusté est faible lorsqu'elle est calculée pour les données d'apprentissage. Une bonne pratique consiste à valider votre modèle sur un jeu de données distinct (c'est-à-dire un jeu de données de validation) avant d'introduire de nouvelles données. Pour les modèles de Machine Learning MATLAB, vous pouvez utiliser la fonction cvpartition
afin de partitionner de manière aléatoire un jeu de données en jeux de données d'apprentissage et de validation. Pour les modèles de Deep Learning, vous pouvez surveiller la précision de la validation pendant l'apprentissage. L'amélioration de la mesure de précision correctement validée pour vos modèles, obtenue par la sélection de modèles et le réglage des hyperparamètres, devrait se traduire par une meilleure précision lorsque le modèle est confronté à de nouvelles données.
La validation croisée est une technique d'évaluation de modèle utilisée pour déterminer les performances d'un algorithme de Machine Learning lorsqu'il réalise des prédictions à partir de jeux de données sur lesquels il n'a pas été entraîné. La validation croisée vous aide à choisir un algorithme dont la complexité modérée ne risque pas de provoquer un surajustement. Utilisez la fonction crossval
pour calculer l'estimation de l'erreur de validation croisée pour les modèles de Machine Learning en utilisant des techniques de validation croisée courantes, telles que la validation k-fold (partitionnement des données en k sous-ensembles choisis de manière aléatoire et de taille à peu près égale) et le holdout (partitionnement aléatoire des données en exactement deux sous-ensembles selon le ratio spécifié).
Régularisation
La régularisation est une technique utilisée pour éviter le surajustement statistique dans un modèle de Machine Learning. Les algorithmes de régularisation fonctionnent généralement par application d'une pénalité liée à la complexité ou l'irrégularité. En introduisant des informations supplémentaires dans le modèle, les algorithmes de régularisation peuvent gérer la multicolinéarité et les prédicteurs redondants pour rendre ainsi le modèle plus parcimonieux et plus précis.
Dans le cadre du Machine Learning, vous avez le choix entre trois techniques de régularisation populaires : Lasso (norme L1), Ridge (norme L2) et Elastic Net, avec plusieurs types de modèles de Machine Learning linéaires. Dans le cadre du Deep Learning, vous pouvez augmenter le facteur de régularisation L2 dans les options d'apprentissage spécifiées ou utiliser des couches de dropout dans votre réseau pour éviter le surajustement.
Exemples et démonstrations
Comment éviter le surajustement en améliorant le jeu de données d'apprentissage
La validation croisée et la régularisation évitent le surajustement en gérant la complexité du modèle. Une autre approche consiste à améliorer le jeu de données. Les modèles de Deep Learning nécessitent en particulier de grandes quantités de données pour éviter le surajustement.
Augmentation des données
Lorsque la disponibilité des données est limitée, l'augmentation des données est une méthode permettant d'étendre artificiellement le jeu de données d'apprentissage en lui ajoutant des versions aléatoires des points de données existants. Avec MATLAB, vous pouvez augmenter des images, des données audio et d'autres types de données. Par exemple, vous pouvez augmenter des images en randomisant l'échelle et la rotation des images existantes.
Génération de données
La génération de données synthétiques est une autre méthode permettant d'étendre un jeu de données. Avec MATLAB, vous pouvez générer des données synthétiques en utilisant des réseaux antagonistes génératifs (GAN) ou des jumeaux numériques (génération de données par simulation).
Nettoyage des données
Le bruit des données contribue au surajustement. Une approche courante permettant de réduire les points de données indésirables consiste à supprimer les valeurs aberrantes des données grâce à la fonction rmoutliers
.