Surajustement

Surajustement

Le surajustement (overfitting) est un comportement de Machine Learning qui se produit lorsque le modèle est si étroitement aligné avec les données d'apprentissage qu'il ne sait pas comment réagir à de nouvelles données. Le surajustement peut se produire pour les raisons suivantes :

  • Le modèle de Machine Learning est trop complexe. Il mémorise des patterns très subtils dans les données d'apprentissage, qui ne se prêtent pas bien à la généralisation.
  • Le volume des données d'apprentissage est trop faible pour la complexité du modèle et/ou les données contiennent de grandes quantités d'informations non pertinentes.

Vous pouvez éviter le surajustement en gérant la complexité du modèle et en améliorant le jeu de données d'apprentissage.

Surajustement et sous-ajustement

Le sous-ajustement (underfitting) est le concept inverse du surajustement : le modèle ne s'aligne pas bien avec les données d'apprentissage ou ne se généralise pas bien à de nouvelles données. Le surajustement et le sous-ajustement peuvent être présents à la fois dans les modèles de classification et de régression. La figure suivante illustre la manière dont la limite de décision de classification et la ligne de régression suivent les données d'apprentissage, de trop près pour un modèle surajusté, et pas assez près pour un modèle sous-ajusté.

Graphiques de données illustrant un surajustement, un apprentissage correct et un sous-ajustement pour des modèles de classification et de régression.

Les modèles de classification et de régression surajustés mémorisent trop bien les données d'apprentissage par rapport aux modèles correctement ajustés.

Si l'on considère uniquement l'erreur calculée d'un modèle de Machine Learning pour les données d'apprentissage, le surajustement est plus difficile à détecter que le sous-ajustement. Ainsi, pour éviter le surajustement, il est important de valider un modèle de Machine Learning avant de l'utiliser sur des données de test.

Erreur

Surajustement

Ajustement correct

Sous-ajustement

Apprentissage

Faible

Faible

Élevée

Test

Élevée

Faible

Élevée

L'erreur calculée pour les modèles surajustés est faible sur les données d'apprentissage, mais élevée sur les données de test.

Utiliser MATLAB® avec Statistics and Machine Learning Toolbox™ et Deep Learning Toolbox™ peut vous aider à éviter le surajustement des modèles de Machine Learning et de Deep Learning. MATLAB propose des fonctions et des méthodes spécifiquement conçues pour éviter le surajustement des modèles. Vous pouvez utiliser ces outils lorsque vous entraînez ou ajustez votre modèle pour le protéger contre le surajustement.

Comment éviter le surajustement en réduisant la complexité du modèle

Avec MATLAB, vous pouvez entraîner des modèles de Machine Learning et de Deep Learning (tels que des CNN) en partant de zéro, ou bien tirer parti de modèles de Deep Learning pré-entraînés. Afin d'éviter le surajustement, effectuez une validation du modèle pour vous assurer que vous choisissez un modèle ayant le bon niveau de complexité pour vos données, ou utilisez la régularisation pour réduire la complexité du modèle.

Validation du modèle

L'erreur d'un modèle surajusté est faible lorsqu'elle est calculée pour les données d'apprentissage. Une bonne pratique consiste à valider votre modèle sur un jeu de données distinct (c'est-à-dire un jeu de données de validation) avant d'introduire de nouvelles données. Pour les modèles de Machine Learning MATLAB, vous pouvez utiliser la fonction cvpartition afin de partitionner de manière aléatoire un jeu de données en jeux de données d'apprentissage et de validation. Pour les modèles de Deep Learning, vous pouvez surveiller la précision de la validation pendant l'apprentissage. L'amélioration de la mesure de précision correctement validée pour vos modèles, obtenue par la sélection de modèles et le réglage des hyperparamètres, devrait se traduire par une meilleure précision lorsque le modèle est confronté à de nouvelles données.

La validation croisée est une technique d'évaluation de modèle utilisée pour déterminer les performances d'un algorithme de Machine Learning lorsqu'il réalise des prédictions à partir de jeux de données sur lesquels il n'a pas été entraîné. La validation croisée vous aide à choisir un algorithme dont la complexité modérée ne risque pas de provoquer un surajustement. Utilisez la fonction crossval pour calculer l'estimation de l'erreur de validation croisée pour les modèles de Machine Learning en utilisant des techniques de validation croisée courantes, telles que la validation k-fold (partitionnement des données en k sous-ensembles choisis de manière aléatoire et de taille à peu près égale) et le holdout (partitionnement aléatoire des données en exactement deux sous-ensembles selon le ratio spécifié).

Régularisation

La régularisation est une technique utilisée pour éviter le surajustement statistique dans un modèle de Machine Learning. Les algorithmes de régularisation fonctionnent généralement par application d'une pénalité liée à la complexité ou l'irrégularité. En introduisant des informations supplémentaires dans le modèle, les algorithmes de régularisation peuvent gérer la multicolinéarité et les prédicteurs redondants pour rendre ainsi le modèle plus parcimonieux et plus précis.

Dans le cadre du Machine Learning, vous avez le choix entre trois techniques de régularisation populaires : Lasso (norme L1), Ridge (norme L2) et Elastic Net, avec plusieurs types de modèles de Machine Learning linéaires. Dans le cadre du Deep Learning, vous pouvez augmenter le facteur de régularisation L2 dans les options d'apprentissage spécifiées ou utiliser des couches de dropout dans votre réseau pour éviter le surajustement.

Comment éviter le surajustement en améliorant le jeu de données d'apprentissage

La validation croisée et la régularisation évitent le surajustement en gérant la complexité du modèle. Une autre approche consiste à améliorer le jeu de données. Les modèles de Deep Learning nécessitent en particulier de grandes quantités de données pour éviter le surajustement.

Augmentation des données

Lorsque la disponibilité des données est limitée, l'augmentation des données est une méthode permettant d'étendre artificiellement le jeu de données d'apprentissage en lui ajoutant des versions aléatoires des points de données existants. Avec MATLAB, vous pouvez augmenter des images, des données audio et d'autres types de données. Par exemple, vous pouvez augmenter des images en randomisant l'échelle et la rotation des images existantes.

Génération de données

La génération de données synthétiques est une autre méthode permettant d'étendre un jeu de données. Avec MATLAB, vous pouvez générer des données synthétiques en utilisant des réseaux antagonistes génératifs (GAN) ou des jumeaux numériques (génération de données par simulation).

Nettoyage des données

Le bruit des données contribue au surajustement. Une approche courante permettant de réduire les points de données indésirables consiste à supprimer les valeurs aberrantes des données grâce à la fonction rmoutliers.

Navigation dans l'interface