predict
Syntax
Description
Ypred = predict(mdl,XNew) returns predicted class labels for the predictor data XNew and the MultinomialRegression model object mdl.
[___] = predict(___,Name=Value) specifies options using one or more name-value arguments in addition to any of the input argument combinations in previous syntaxes. For example, you can specify the type of probability for the probability estimates returned in probs.
Examples
Load the fisheriris sample data set.
load fisheriris
The column vector species contains three iris flower species: setosa, versicolor, and virginica. The matrix meas contains four types of measurements for each flower: the length and width of the sepals and petals, in centimeters.
Divide the species and measurement data into training and test data by using the cvpartition function. Get the indices of the training data rows by using the training function.
n = length(species);
partition = cvpartition(n,'Holdout',0.05);
idx_train = training(partition);
Create training data by using the indices of the training data rows to create a matrix of measurements and a vector of species labels.
meastrain = meas(idx_train,:);
speciestrain = species(idx_train,:);
Fit a multinomial regression model using the training data.
mdl = fitmnr(meastrain,speciestrain)
mdl =
Multinomial regression with nominal responses
                               Value       SE       tStat        pValue  
                              _______    ______    ________    __________
    (Intercept_setosa)         86.297    12.541       6.881    5.9436e-12
    x1_setosa                 -1.0653    3.5795    -0.29761         0.766
    x2_setosa                  23.849    3.1238      7.6347    2.2637e-14
    x3_setosa                 -27.273    3.5009     -7.7903    6.6846e-15
    x4_setosa                 -59.644    7.0214     -8.4947    1.9852e-17
    (Intercept_versicolor)     42.637    5.2214      8.1659    3.1906e-16
    x1_versicolor              2.4652    1.1263      2.1887      0.028619
    x2_versicolor              6.6808     1.474      4.5325     5.829e-06
    x3_versicolor             -9.4292    1.2946     -7.2837     3.248e-13
    x4_versicolor             -18.286    2.0833     -8.7775     1.671e-18
143 observations, 276 error degrees of freedom
Dispersion: 1
Chi^2-statistic vs. constant model: 302.0378, p-value = 1.5168e-60
mdl is a multinomial regression model object that contains the results of fitting a nominal multinomial regression model to the data. The table output shows coefficient statistics for each predictor in meas. By default, fitmnr uses virginica as the reference category.
Get the indices of the test data rows by using the test function. Create test data by using the indices of the test data rows to create a matrix of measurements and a vector of species labels.
idx_test = test(partition);
meastest = meas(idx_test,:);
speciestest = species(idx_test,:);
Predict the iris species for the measurements in meastest.
speciespredict = predict(mdl,meastest)
speciespredict = 7×1 cell
{'setosa' }
{'setosa' }
{'setosa' }
{'setosa' }
{'setosa' }
{'versicolor'}
{'versicolor'}
Compare the predictions in speciespredict with the category names in speciestest.
speciestest
speciestest = 7×1 cell
{'setosa' }
{'setosa' }
{'setosa' }
{'setosa' }
{'setosa' }
{'versicolor'}
{'versicolor'}
The output shows that the model accurately predicts the iris species for the measurements in meastest.
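The visual comparison above can also be done programmatically. The following is a minimal sketch, not part of the original example: it repeats the setup so that it runs on its own, and the names idxTrain, idxTest, isCorrect, and accuracy are illustrative. The call to rng is an added assumption that makes the random holdout partition repeatable, so the partition can differ from the one shown above.

```matlab
% Sketch: measure how often predict agrees with the true test labels.
load fisheriris
rng("default")                                   % assumption: fix the seed for repeatability
partition = cvpartition(length(species),'Holdout',0.05);
idxTrain = training(partition);
idxTest = test(partition);

% Fit on the training rows and predict the held-out rows.
mdl = fitmnr(meas(idxTrain,:),species(idxTrain));
speciespredict = predict(mdl,meas(idxTest,:));
speciestest = species(idxTest);

% Both label sets are cell arrays of character vectors, so strcmp
% compares them element by element.
isCorrect = strcmp(speciespredict,speciestest);
accuracy = mean(isCorrect)                       % fraction of correct predictions
misclassified = find(~isCorrect)                 % test-row indices of any errors
```

With a holdout set this small, a single misclassification changes the accuracy substantially, so treat the value as a rough check rather than a stable estimate.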
Load the carbig sample data set.
load carbig
The variables Acceleration and Displacement contain data for car acceleration and displacement, respectively. The variable Cylinders contains data for the number of cylinders in each car engine.
Create a table from the car data variables using the table function.
tbl = table(Acceleration,Displacement,Cylinders,VariableNames=["Acceleration","Displacement","Cylinders"])
tbl=406×3 table
Acceleration Displacement Cylinders
____________ ____________ _________
12 307 8
11.5 350 8
11 318 8
12 304 8
10.5 302 8
10 429 8
9 454 8
8.5 440 8
10 455 8
8.5 390 8
17.5 133 4
11.5 350 8
11 351 8
10.5 383 8
11 360 8
10 383 8
⋮
The Cylinders data has an inherent ordering. Fit an ordinal multinomial regression model using Acceleration and Displacement as predictor variables and Cylinders as the response.
mdl = fitmnr(tbl,"Cylinders",ModelType="ordinal");
mdl is a multinomial regression model object that contains the results of fitting an ordinal multinomial regression model to the data.
Predict the response category, cumulative category probabilities, and 99% confidence interval bounds for a car with an acceleration of 16 and an engine displacement of 80.
[cylinderspredict,cumprobs,lower,upper] = predict(mdl,[16 80],Alpha=0.01,ProbabilityType="cumulative")
cylinderspredict = 4
cumprobs = 1×4
0.0792 1.0000 1.0000 1.0000
lower = 1×4
0.0787 1.0000 1.0000 1.0000
upper = 1×4
0.0798 1.0000 1.0000 1.0000
The output shows that the predicted response category is 4. The vector cumprobs shows the cumulative probabilities for each category in Cylinders. To view the category probabilities on which the prediction is based, calculate the category probabilities.
[~,catprobs] = predict(mdl,[16 80])
catprobs = 1×5
0.0792 0.9208 0.0000 0.0000 0.0000
The second value in the vector catprobs has the highest probability. Display an ordered list of the categories in Cylinders.
mdl.ClassNames
ans = 5×1
3
4
5
6
8
The output shows that the second category corresponds to cars with four cylinders. Therefore, the category with the highest category probability is 4.
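The relationship between the two probability types in this example can be checked numerically. The sketch below is illustrative rather than part of the original example: it refits the model so that it runs on its own, and the names cumFromCat, maxDiff, and idx are hypothetical. It assumes, consistent with the 1-by-4 and 1-by-5 outputs above, that the cumulative probabilities are running sums of the category probabilities and that the final cumulative value (always 1) is omitted.

```matlab
% Sketch: relate category and cumulative probabilities for the ordinal model.
load carbig
tbl = table(Acceleration,Displacement,Cylinders);
mdl = fitmnr(tbl,"Cylinders",ModelType="ordinal");

[ypred,catprobs] = predict(mdl,[16 80]);
[~,cumprobs] = predict(mdl,[16 80],ProbabilityType="cumulative");

% Running sum of the category probabilities along each row.
cumFromCat = cumsum(catprobs,2);

% Compare only the columns that cumprobs reports.
maxDiff = max(abs(cumFromCat(:,1:size(cumprobs,2)) - cumprobs))

% Map the most probable column back to its category label.
[~,idx] = max(catprobs,[],2);
mostProbable = mdl.ClassNames(idx)
```

If the assumption holds, maxDiff is zero up to rounding error, and mostProbable matches the predicted category in ypred.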
Input Arguments
mdl: Multinomial regression model object, specified as a MultinomialRegression model object created with the fitmnr function.
XNew: New predictor input values, specified as a table or an n-by-p matrix, where n is the number of observations to predict, and p is the number of predictor variables used to fit mdl.
- If XNew is a table, it must contain all the names of the predictors used to fit mdl. You can find the predictor names in the mdl.PredictorNames property.
- If XNew is a matrix, it must have the same number of columns as the number of predictor variables used to fit mdl. You can find the number of predictors in the mdl.NumPredictors property. You can specify XNew as a matrix only when all names in mdl.PredictorNames refer to numeric predictors.
Example: predict(mdl,[6.2 3.4; 5.9 3.0]) evaluates the two-predictor model mdl at the points p1 = [6.2 3.4] and p2 = [5.9 3.0].
Data Types: single | double | table
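The two input forms for XNew can be compared side by side. This is a minimal sketch that assumes the ordinal carbig model from the Examples section. The query values and the names YpredTable and YpredMatrix are illustrative, and the matrix columns are assumed to follow the order of mdl.PredictorNames.

```matlab
% Sketch: pass new predictor data as a table and as a matrix.
load carbig
tbl = table(Acceleration,Displacement,Cylinders);
mdl = fitmnr(tbl,"Cylinders",ModelType="ordinal");

% Table input: variable names must match the names in mdl.PredictorNames.
XNewTable = table([16; 12],[80; 307], ...
    VariableNames=["Acceleration","Displacement"]);
YpredTable = predict(mdl,XNewTable)

% Matrix input: one column per predictor, one row per observation.
XNewMatrix = [16 80; 12 307];
YpredMatrix = predict(mdl,XNewMatrix)
```

Table input is the safer choice when the column order might change, because predictors are matched by name rather than by position.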
Name-Value Arguments
Specify optional pairs of arguments as
Name1=Value1,...,NameN=ValueN, where Name is
the argument name and Value is the corresponding value.
Name-value arguments must appear after other arguments, but the order of the
pairs does not matter.
Example: [Ypred,probs,lower,upper] = predict(model,X,Alpha=0.01,ProbabilityType="cumulative") specifies a 99% confidence level for the probability estimates and their type as cumulative.
Alpha: Significance level for the probability estimates, specified as a scalar value in the range (0,1). The confidence level of the confidence intervals is 100(1 − α)%. The default value for Alpha is 0.05, which returns 95% confidence intervals for the estimates.
Example: Alpha=0.01
Data Types: single | double
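The effect of Alpha on the interval bounds can be seen by comparing two calls. The sketch below assumes the ordinal carbig model from the Examples section; the variable names are illustrative. It relies on the general property that a higher confidence level produces intervals at least as wide.

```matlab
% Sketch: compare 95% and 99% confidence intervals for the same query point.
load carbig
tbl = table(Acceleration,Displacement,Cylinders);
mdl = fitmnr(tbl,"Cylinders",ModelType="ordinal");

% Default Alpha=0.05 gives 95% intervals.
[~,probs,lower95,upper95] = predict(mdl,[16 80]);

% Alpha=0.01 gives 99% intervals.
[~,~,lower99,upper99] = predict(mdl,[16 80],Alpha=0.01);

width95 = upper95 - lower95
width99 = upper99 - lower99
```

For categories whose estimated probability is essentially 0 or 1, both widths can be close to zero, so the difference is visible only in the remaining columns.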
ProbabilityType: Type of probability estimates to return in probs, specified as one of the following options.
| Option | Description |
|---|---|
| "category" (default) | Calculate a distinct probability for each response category. |
| "cumulative" | Calculate a cumulative probability for each response category. |
| "conditional" | Calculate a conditional probability for each response category. |
Example: ProbabilityType="conditional"
Data Types: char | string
Output Arguments
Ypred: Predicted response categories, returned as a categorical or character array, logical or numeric vector, or cell array of character vectors. Ypred has the same data type as mdl.ClassNames.
probs: Probability estimates for the response categories, returned as a numeric matrix. Each column of probs corresponds to the entry at the same index in mdl.ClassNames.
lower: Lower confidence interval bound for the probability estimates in probs, returned as a numeric matrix.
upper: Upper confidence interval bound for the probability estimates in probs, returned as a numeric matrix.
Alternative Functionality
Version History
Introduced in R2023a