cvshrink
Cross-validate pruning and regularization of regression ensemble
Description
vals = cvshrink(ens) returns an L-by-T matrix with cross-validated values of the mean squared error. L is the number of Lambda values in the ens.Regularization structure, and T is the number of Threshold values on weak learner weights. If ens does not have a Regularization property containing values specified by the regularize function, set the Lambda name-value argument.
[vals,nlearn] = cvshrink(ens) also returns an L-by-T matrix with the mean number of learners in the cross-validated ensemble.
[___] = cvshrink(ens,Name=Value) specifies additional options using one or more name-value arguments. For example, you can specify the number of folds to use, the fraction of data to use for holdout validation, and lower cutoffs on weights for weak learners.
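For example, here is a minimal sketch of the two output syntaxes, assuming an ensemble ens trained with fitrensemble and purely illustrative Lambda and Threshold grids; vals(i,j) and nlearn(i,j) correspond to lambda(i) and threshold(j).
lambda = [0.001 0.01 0.1];   % hypothetical lasso regularization values
threshold = [0 0.01];        % hypothetical lower cutoffs on learner weights
vals = cvshrink(ens,Lambda=lambda,Threshold=threshold);           % L-by-T matrix of cross-validated MSE
[vals,nlearn] = cvshrink(ens,Lambda=lambda,Threshold=threshold);  % MSE and mean learner counts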
Examples
Cross-Validate Regression Ensemble
Create a regression ensemble for predicting mileage from the carsmall
data. Cross-validate the ensemble.
Load the carsmall
data set and select displacement, horsepower, and vehicle weight as predictors.
load carsmall
X = [Displacement Horsepower Weight];
You can train an ensemble of bagged regression trees.
ens = fitrensemble(X,MPG,Method="Bag");
fitrensemble uses a default template tree object templateTree() as a weak learner when Method is "Bag". In this example, for reproducibility, specify Reproducible=true when you create a tree template object, and then use the object as a weak learner.
rng("default") % For reproducibility
t = templateTree(Reproducible=true); % For reproducibility of random predictor selections
ens = fitrensemble(X,MPG,Method="Bag",Learners=t);
Specify values for Lambda
and Threshold
. Use these values to cross-validate the ensemble.
[vals,nlearn] = cvshrink(ens,Lambda=[.01 .1 1],Threshold=[0 .01 .1])
vals = 3×3
18.9150 19.0092 128.5935
18.9099 18.9504 128.8449
19.0328 18.9636 116.8500
nlearn = 3×3
13.7000 11.6000 4.1000
13.7000 11.7000 4.1000
13.9000 11.6000 4.1000
Clearly, setting a threshold of 0.1 leads to unacceptable errors, while a threshold of 0.01 gives errors similar to a threshold of 0. The mean number of learners with a threshold of 0.01 is about 11.6, whereas the mean number is about 13.8 when the threshold is 0.
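As an optional follow-up, you might visualize the cross-validated MSE against Lambda, with one curve per Threshold value. This is a minimal plotting sketch that reuses the vals matrix and the grids from the cvshrink call above.
lambda = [.01 .1 1];
threshold = [0 .01 .1];
figure
semilogx(lambda,vals,"-o")   % columns of vals correspond to Threshold values
xlabel("Lambda")
ylabel("Cross-validated MSE")
legend("Threshold = " + string(threshold),Location="northwest")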
Input Arguments
ens
— Regression ensemble model
RegressionEnsemble
model object | RegressionBaggedEnsemble
model object
Regression ensemble model, specified as a RegressionEnsemble
or RegressionBaggedEnsemble
model object trained with fitrensemble
.
Name-Value Arguments
Specify optional pairs of arguments as
Name1=Value1,...,NameN=ValueN
, where Name
is
the argument name and Value
is the corresponding value.
Name-value arguments must appear after other arguments, but the order of the
pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose
Name
in quotes.
Example: cvshrink(ens,Holdout=0.1,Threshold=[0 .01 .1])
specifies to reserve 10% of the data for holdout validation, and to cross-validate using the lower weight cutoffs 0, 0.01, and 0.1 on the weak learners.
CVPartition
— Cross-validation partition
[]
(default) | cvpartition
object
Cross-validation partition, specified as a cvpartition
object that specifies the type of cross-validation and the
indexing for the training and validation sets.
To create a cross-validated model, you can specify only one of these four name-value
arguments: CVPartition
, Holdout
,
KFold
, or Leaveout
.
Example: Suppose you create a random partition for 5-fold cross-validation on 500
observations by using cvp = cvpartition(500,KFold=5)
. Then, you can
specify the cross-validation partition by setting
CVPartition=cvp
.
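As a minimal sketch (assuming ens is a trained RegressionEnsemble), you can create one cvpartition object and reuse it so that different Lambda and Threshold grids are evaluated on identical folds.
cvp = cvpartition(ens.NumObservations,KFold=5);            % 5-fold partition of the training observations
vals = cvshrink(ens,Lambda=[0.01 0.1 1],CVPartition=cvp);  % cross-validate on the same folds each time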
Holdout
— Fraction of data for holdout validation
scalar value in the range (0,1)
Fraction of the data used for holdout validation, specified as a scalar value in the range
(0,1). If you specify Holdout=p
, then the software completes these
steps:
Randomly select and reserve p*100% of the data as validation data, and train the model using the rest of the data.
Store the compact trained model in the Trained property of the cross-validated model.
To create a cross-validated model, you can specify only one of these four name-value
arguments: CVPartition
, Holdout
,
KFold
, or Leaveout
.
Example: Holdout=0.1
Data Types: double
| single
KFold
— Number of folds
10
(default) | positive integer value greater than 1
Number of folds to use in the cross-validated model, specified as a positive integer value
greater than 1. If you specify KFold=k
, then the software completes
these steps:
Randomly partition the data into k sets.
For each set, reserve the set as validation data, and train the model using the other k – 1 sets.
Store the k compact trained models in a k-by-1 cell vector in the Trained property of the cross-validated model.
To create a cross-validated model, you can specify only one of these four name-value
arguments: CVPartition
, Holdout
,
KFold
, or Leaveout
.
Example: KFold=5
Data Types: single
| double
Lambda
— Regularization parameter values
[]
(default) | vector of nonnegative scalar values
Regularization parameter values for lasso, specified as a vector of
nonnegative scalar values. If the value of this argument is empty,
cvshrink
does not perform
cross-validation.
Example: Lambda=[.01 .1 1]
Data Types: single
| double
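As a hedged sketch, if you first call regularize (here with its default lambda sequence), the ensemble gains a Regularization property, and cvshrink can then be called without the Lambda argument to cross-validate over those stored values.
ens = regularize(ens);   % fills ens.Regularization with default lambda values
vals = cvshrink(ens);    % cross-validates over the stored lambda values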
Leaveout
— Leave-one-out cross-validation flag
"off"
(default) | "on"
Leave-one-out cross-validation flag, specified as "on"
or
"off"
. If you specify Leaveout="on"
, then for
each of the n observations (where n is the number
of observations, excluding missing observations, specified in the
NumObservations
property of the model), the software completes
these steps:
Reserve the one observation as validation data, and train the model using the other n – 1 observations.
Store the n compact trained models in an n-by-1 cell vector in the
Trained
property of the cross-validated model.
To create a cross-validated model, you can specify only one of these four name-value
arguments: CVPartition
, Holdout
,
KFold
, or Leaveout
.
Example: Leaveout="on"
Data Types: char
| string
Threshold
— Weights threshold
0 (default) | numeric vector
Weights threshold, specified as a numeric vector with lower cutoffs on
weights for weak learners. cvshrink
discards learners with weights below Threshold
in its
cross-validation calculation.
Example: Threshold=[0 .01 .1]
Data Types: single
| double
Output Arguments
vals
— Cross-validated values of mean squared error
numeric matrix
Cross-validated values of the mean squared error, returned as an
L
-by-T
numeric matrix.
L
is the number of values of the regularization
parameter Lambda
, and T
is the number
of Threshold
values on weak learner weights.
nlearn
— Mean number of learners
numeric matrix
Mean number of learners in the cross-validated ensemble, returned as an
L
-by-T
numeric matrix.
L
is the number of values of the regularization
parameter Lambda
, and T
is the number
of Threshold
values on weak learner weights.
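For example, here is a minimal selection sketch that uses vals and nlearn together; it assumes vals, nlearn, lambda, and threshold come from an earlier cvshrink call, and the 5% tolerance is an arbitrary illustrative choice. Among settings whose cross-validated MSE is within 5% of the minimum, it keeps the one that uses the fewest learners.
tol = 1.05*min(vals(:));            % acceptable MSE level (illustrative 5% tolerance)
nlearnMasked = nlearn;
nlearnMasked(vals > tol) = Inf;     % exclude settings with too-high error
[~,idx] = min(nlearnMasked(:));
[i,j] = ind2sub(size(vals),idx);
bestLambda = lambda(i);
bestThreshold = threshold(j);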
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
This function fully supports GPU arrays. For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2011a
See Also
regularize
| shrink
| RegressionEnsemble
| RegressionBaggedEnsemble
| fitrensemble