Train an ensemble (stacked) model from the base learners in a
caret_list. The ensemble is itself a caret::train model that learns to
combine the predictions of the base models. By default, the meta-learner is
trained on out-of-fold predictions from the resampling process, ensuring that
the ensemble does not overfit to in-sample predictions. Alternatively, new
datasets can be supplied via data_list and target for transfer-learning-style
ensembling.
Usage
caret_stack(
  caret_list,
  method,
  data_list = NULL,
  target = NULL,
  trControl = NULL,
  metric = NULL,
  ...
)
Arguments
- caret_list
A caret_list object.
- method
The method to train the ensemble model. Can be a custom method or one found in caret::modelLookup().
- data_list
A list of datasets to predict on, with each dataset matching the corresponding model in caret_list. If NULL, the out-of-fold predictions from the base models will be used. Default is NULL.
- target
Target parameter vector that must be provided if predicting on a new data list. If NULL, the target vector used to train the base models will be used.
- trControl
Control for use with the caret::train function. If NULL, .default_control() is used to construct a default control (5-fold cross validation) depending on the target type. Default is NULL.
- metric
Metric for use with the caret::train function. If NULL, .default_metric() is used to construct a default metric depending on the target type. Default is NULL.
- ...
Additional arguments to pass to caret::train.
Examples
set.seed(42)
data(heart_failure_datasets)
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
# Define hyperparameters to tune (optional)
tuneGrid <- expand.grid(alpha = 0.5, lambda = c(0.01, 0.1))
# Construct caret_list object
base_models <- caret_list(
  target = heart_failure_datasets$demo$hospitalizations,
  data_list = data_list,
  method = "glmnet",
  tuneGrid = tuneGrid
)
# Train a Random Forest stacked model on the out-of-fold predictions from the base models
stacked_model <- caret_stack(
  caret_list = base_models,
  method = "rf"
)
class(stacked_model)
#> [1] "caret_stack"
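The optional arguments described above can be sketched as follows. This is a minimal illustration, not a recommended configuration: the package name in the library() call is an assumption (it is not stated on this page), the 10-fold control is an arbitrary choice, and the second call reuses the training datasets only to show the shape of the data_list and target arguments. In practice a separate set of datasets would be supplied there.

```r
library(caretMultimodal)  # assumption: package name is not stated on this page

set.seed(42)
data(heart_failure_datasets)
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
target <- heart_failure_datasets$demo$hospitalizations

# Base learners, built the same way as in the example above
base_models <- caret_list(
  target = target,
  data_list = data_list,
  method = "glmnet",
  tuneGrid = expand.grid(alpha = 0.5, lambda = c(0.01, 0.1))
)

# Replace the default 5-fold control with 10-fold cross-validation.
# savePredictions = "final" is a precaution, in case other helpers
# expect the meta-learner's resampled predictions (an assumption).
stacked_cv <- caret_stack(
  caret_list = base_models,
  method = "rf",
  trControl = caret::trainControl(
    method = "cv",
    number = 10,
    savePredictions = "final"
  )
)

# Train the meta-learner on predictions for explicitly supplied datasets
# instead of the out-of-fold predictions. The training data is reused here
# purely for illustration, so these predictions are in-sample.
stacked_new <- caret_stack(
  caret_list = base_models,
  method = "rf",
  data_list = data_list,
  target = target
)
```

When data_list is supplied, each element should line up with the corresponding base model in caret_list, and target must be given alongside it.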