Train Neural Network Classifiers Using Classification Learner App

This example shows how to create and compare neural network classifiers in the Classification Learner app, and export trained models to the workspace to make predictions for new data.

  1. In the MATLAB® Command Window, read the fisheriris data set into a table to use for classification.

    fishertable = readtable('fisheriris.csv');
    
  2. Click the Apps tab, and then click the Show more arrow on the right to open the apps gallery. In the Machine Learning and Deep Learning group, click Classification Learner.

  3. On the Classification Learner tab, in the File section, click New Session and select From Workspace.

  4. In the New Session from Workspace dialog box, select the table fishertable from the Data Set Variable list (if necessary). Observe that the app has selected response and predictor variables based on their data types. Petal and sepal length and width are predictors, and species is the response that you want to classify. For this example, do not change the selections.

  5. To accept the default validation scheme and continue, click Start Session. The default validation option is 5-fold cross-validation, which protects against overfitting.

    Classification Learner creates a scatter plot of the data.

  6. Use the scatter plot to investigate which variables are useful for predicting the response. Select different options in the X and Y lists under Predictors to visualize the distribution of species and measurements. Note which variables separate the species colors most clearly.

  7. Create a selection of neural network models. On the Classification Learner tab, in the Model Type section, click the arrow to open the gallery. In the Neural Network Classifiers group, click All Neural Networks.

  8. In the Training section, click Train. Classification Learner trains one of each neural network classification option in the gallery. In the Models pane, the app outlines the Accuracy (Validation) score of the best model. Classification Learner also displays a validation confusion matrix for the first neural network model (Narrow Neural Network).

    Tip

    If you have Parallel Computing Toolbox™, you can train all the models (All Neural Networks) simultaneously by selecting the Use Parallel button in the Training section before clicking Train. After you click Train, the Opening Parallel Pool dialog box opens and remains open while the app opens a parallel pool of workers. During this time, you cannot interact with the software. After the pool opens, the app trains the models simultaneously.

  9. Select a model in the Models pane to view the results. For example, select the Narrow Neural Network model (model 1.1). Inspect the Current Model Summary pane, which displays the Training Results metrics calculated on the validation set.

  10. Examine the scatter plot for the trained model. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, and then click Scatter in the Validation Results group. Correctly classified points are marked with an O, and incorrectly classified points are marked with an X.

    [Figure: Scatter plot of the Fisher iris data modeled by a neural network classifier]

    Note

    Validation introduces some randomness into the results. Your model validation results can vary from the results shown in this example.

  11. Inspect the accuracy of the predictions in each class. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, and then click Confusion Matrix (Validation) in the Validation Results group. View the matrix of true class and predicted class results.

  12. Select the other models in the Models pane, open the validation confusion matrix for each of the models, and then compare the results.

  13. Choose the best model in the Models pane (the best score is highlighted in the Accuracy (Validation) box). To try to improve the model, change which features it includes. For example, check whether removing features with low predictive power improves the validation accuracy.

    On the Classification Learner tab, in the Features section, click Feature Selection. In the Feature Selection dialog box, specify predictors to remove from the model, and click OK. In the Training section, click Train to train a new model using the new options. Compare results among the classifiers in the Models pane.

  14. To investigate features to include or exclude, use the scatter and parallel coordinates plots. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, and click Parallel Coordinates in the Validation Results group.

  15. Choose the best model in the Models pane. To try to improve the model further, change its advanced settings. On the Classification Learner tab, in the Model Type section, click Advanced and select Advanced. In the Advanced Neural Network Options dialog box, try changing some of the settings, like the sizes of the fully connected layers or the regularization strength, and click OK. Train the new model by clicking Train in the Training section.

    To learn more about neural network model settings, see Neural Network Classifiers.

  16. You can export a full or compact version of the trained model to the workspace. On the Classification Learner tab, in the Export section, click Export Model and select either Export Model or Export Compact Model. See Export Classification Model to Predict New Data.

  17. To examine the code for training this classifier, click Generate Function in the Export section.
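The app performs steps 5 through 16 interactively. If you want a rough command-line counterpart, the following is a minimal sketch using fitcnet. It is not the code that Generate Function produces. It assumes Statistics and Machine Learning Toolbox R2021a or later (for fitcnet) and assumes that the CSV columns are named SepalLength, SepalWidth, PetalLength, PetalWidth, and Species. It also assumes that the Narrow, Medium, Wide, Bilayered, and Trilayered presets correspond to the layer sizes in the comments, with ReLU activations, no regularization, standardized data, and an iteration limit of 1000. The regularization strength of 1e-3 and the new flower's measurements are hypothetical values chosen only for illustration.

```matlab
% Sketch: train and compare neural network classifiers from the command line.
fishertable = readtable('fisheriris.csv');
rng('default')  % repeatable folds and weight initialization

% One stratified 5-fold partition shared by all models, so every model is
% scored on the same held-out folds.
cvp = cvpartition(fishertable.Species, 'KFold', 5);

% Assumed layer sizes of the app presets (one entry per preset).
presetNames = ["Narrow"; "Medium"; "Wide"; "Bilayered"; "Trilayered"];
presetSizes = {10; 25; 100; [10 10]; [10 10 10]};

cvModels = cell(numel(presetNames), 1);
accuracy = zeros(numel(presetNames), 1);
for k = 1:numel(presetNames)
    cvModels{k} = fitcnet(fishertable, 'Species', ...
        'LayerSizes', presetSizes{k}, ...
        'Activations', 'relu', ...
        'Lambda', 0, ...
        'IterationLimit', 1000, ...
        'Standardize', true, ...
        'CVPartition', cvp);
    % kfoldLoss returns the misclassification rate on the held-out folds.
    accuracy(k) = 1 - kfoldLoss(cvModels{k});
end
results = table(presetNames, accuracy, ...
    'VariableNames', {'Model', 'ValidationAccuracy'})

% Validation confusion matrix for the most accurate preset.
[~, best] = max(accuracy);
validationPredictions = kfoldPredict(cvModels{best});
confusionchart(fishertable.Species, validationPredictions);

% Feature selection: use only the petal measurements (assumed column names).
cvPetal = fitcnet(fishertable, 'Species', ...
    'PredictorNames', {'PetalLength', 'PetalWidth'}, ...
    'LayerSizes', presetSizes{best}, 'Standardize', true, ...
    'CVPartition', cvp);
petalAccuracy = 1 - kfoldLoss(cvPetal)

% Advanced option: hypothetical nonzero regularization strength.
cvRegularized = fitcnet(fishertable, 'Species', ...
    'LayerSizes', presetSizes{best}, 'Standardize', true, ...
    'Lambda', 1e-3, 'CVPartition', cvp);
regularizedAccuracy = 1 - kfoldLoss(cvRegularized)

% Train the chosen configuration on all the data, then classify a new
% flower (hypothetical measurements, same units as the training data).
finalModel = fitcnet(fishertable, 'Species', ...
    'LayerSizes', presetSizes{best}, 'Standardize', true);
newFlower = table(5.9, 3.0, 5.1, 1.8, 'VariableNames', ...
    {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'});
predictedSpecies = predict(finalModel, newFlower)
```

Because all models share one cvpartition object, differences in ValidationAccuracy reflect the models rather than different data splits. The exact accuracies depend on the random seed, just as the app's validation results can vary.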

Tip

Use the same workflow to evaluate and compare the other classifier types you can train in Classification Learner.

To train all the nonoptimizable classifier model presets available for your data set:

  1. On the Classification Learner tab, in the Model Type section, click the arrow to open the gallery of models.

  2. In the Get Started group, click All. Then, in the Training section, click Train.

    [Figure: Option selected for training all available classifier types]

To learn about other classifier types, see Train Classification Models in Classification Learner App.
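The idea of comparing models by validation accuracy also carries over to the command line for other classifier families. The following minimal sketch cross-validates several Statistics and Machine Learning Toolbox fitting functions on one shared 5-fold partition. Each function runs with its own default settings. This is an assumption of the sketch, and those defaults do not necessarily match the presets that the All option trains in the app. The response column name Species is also assumed, and fitcnet requires R2021a or later.

```matlab
% Sketch: compare several classifier types with 5-fold cross-validation.
fishertable = readtable('fisheriris.csv');
rng('default')  % repeatable folds
cvp = cvpartition(fishertable.Species, 'KFold', 5);

% Fitting functions called with their default options (not the app presets).
fitters = {@fitctree, @fitcdiscr, @fitcnb, @fitcknn, @fitcecoc, @fitcnet};
names = ["Tree"; "Discriminant"; "Naive Bayes"; "KNN"; "SVM (ECOC)"; "Neural network"];

accuracy = zeros(numel(fitters), 1);
for k = 1:numel(fitters)
    % Each call returns a cross-validated (partitioned) model.
    cvMdl = fitters{k}(fishertable, 'Species', 'CVPartition', cvp);
    accuracy(k) = 1 - kfoldLoss(cvMdl);
end

% Sort so that the most accurate model type appears first.
comparison = sortrows(table(names, accuracy, ...
    'VariableNames', {'Model', 'ValidationAccuracy'}), ...
    'ValidationAccuracy', 'descend')
```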
