Determine stacking coefficients from a data stack
Source: R/blend_predictions.R
Evaluates a data stack by fitting a regularized model on the assessment predictions from each candidate member to predict the true outcome.

This process determines the "stacking coefficients" of the model stack. The stacking coefficients are used to weight the predictions from each candidate (represented by a unique column in the data stack), and are given by the betas of a LASSO model fitting the true outcome with the predictions given in the remaining columns of the data stack.

Candidates with non-zero stacking coefficients are model stack members, and need to be trained on the full training set (rather than just the assessment set) with fit_members(). This function is typically used after a number of calls to add_candidates().
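To situate blend_predictions() in the larger stacking workflow, here is a minimal sketch; the tuning results svm_res and spline_res are hypothetical stand-ins for objects like those shown in the Examples below.

library(stacks)

# a sketch of the core verb pipeline; `svm_res` and `spline_res`
# stand in for tuning results created with tune::tune_grid()
model_st <-
  stacks() %>%                    # initialize an empty data stack
  add_candidates(svm_res) %>%     # add assessment set predictions
  add_candidates(spline_res) %>%
  blend_predictions() %>%         # determine stacking coefficients
  fit_members()                   # train retained members on the full training set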
Usage
blend_predictions(
  data_stack,
  penalty = 10^(-6:-1),
  mixture = 1,
  non_negative = TRUE,
  metric = NULL,
  control = tune::control_grid(),
  times = 25,
  ...
)
Arguments
- data_stack: A data_stack object.

- penalty: A numeric vector of proposed values for the total amount of regularization used in member weighting. Higher penalties will generally result in fewer members being included in the resulting model stack, and vice versa. The package will tune over a grid formed from the cross product of the penalty and mixture arguments.

- mixture: A number between zero and one (inclusive) giving the proportion of L1 regularization (i.e. lasso) in the model. mixture = 1 indicates a pure lasso model, mixture = 0 indicates ridge regression, and values in (0, 1) indicate an elastic net. The package will tune over a grid formed from the cross product of the penalty and mixture arguments.

- non_negative: A logical giving whether to restrict stacking coefficients to non-negative values. If TRUE (the default), 0 is passed as the lower.limits argument to glmnet::glmnet() in fitting the model on the data stack; otherwise, -Inf.

- metric: A call to yardstick::metric_set(). The metric(s) to use in tuning the lasso penalty on the stacking coefficients. Default values are determined by tune::tune_grid() from the outcome class.

- control: An object inheriting from control_grid to be passed to the model determining stacking coefficients. See the tune::control_grid() documentation for details on possible values. Note that any extract entry will be overwritten internally.

- times: Number of bootstrap samples tuned over by the model that determines stacking coefficients. See rsample::bootstraps() to learn more.

- ...: Additional arguments. Currently ignored.
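As a sketch of how the penalty and mixture arguments combine, the call below proposes a grid over both; reg_st is the data stack assembled in the Examples section, and the specific grid values are illustrative assumptions rather than recommendations.

# tune over the cross product of 20 penalty values and three
# mixture values (ridge, elastic net, and pure lasso); the grid
# values here are illustrative only
reg_st %>%
  blend_predictions(
    penalty = 10^seq(-3, -0.5, length.out = 20),
    mixture = c(0, 0.5, 1)
  )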
Value
A model_stack object. While model_stack objects largely contain the same elements as data_stack objects, the primary data objects shift from the assessment set predictions to the member models.
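To inspect the stacking coefficients in the returned object, the package's collect_parameters() function can be used; a minimal sketch, where reg_st_blended is a hypothetical output of blend_predictions():

# a sketch of retrieving stacking coefficients (and member
# hyperparameters) for one model definition; `reg_st_blended`
# is a hypothetical blend_predictions() output
collect_parameters(reg_st_blended, candidates = "reg_res_svm")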
Details
Note that a regularized linear model is one of many possible learning algorithms that could be used to fit a stacked ensemble model. For implementations of additional ensemble learning algorithms, see h2o::h2o.stackedEnsemble() and SuperLearner::SuperLearner().
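To make the role of the stacking coefficients concrete, the sketch below combines two members' predictions by hand; all values are made up for illustration.

# conceptual sketch: a stacked prediction is the weighted sum of
# member predictions, with weights given by the stacking
# coefficients (the LASSO betas); all values below are made up
pred_svm <- c(45.1, 38.7, 52.3)  # member 1 predictions
pred_sp  <- c(44.0, 40.2, 50.9)  # member 2 predictions
coefs    <- c(0.64, 0.49)        # stacking coefficients

ensemble_pred <- coefs[1] * pred_svm + coefs[2] * pred_sp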
Example Data
This package provides some resampling objects and datasets, derived from a study on 1212 red-eyed tree frog embryos, for use in examples and vignettes!

Red-eyed tree frog (RETF) embryos can hatch earlier than their normal 7ish days if they detect a potential predator threat. Researchers wanted to determine how, and when, these tree frog embryos were able to detect stimulus from their environment. To do so, they subjected the embryos at varying developmental stages to "predator stimulus" by jiggling the embryos with a blunt probe. Beforehand, though, some of the embryos were treated with gentamicin, a compound that knocks out their lateral line (a sensory organ). Researcher Julie Jung and her crew found that these factors inform whether an embryo hatches prematurely or not!
Note that the data included with the stacks package is not necessarily a representative or unbiased subset of the complete dataset, and is only for demonstrative purposes.
reg_folds and class_folds are rset cross-fold validation objects from rsample, splitting the training data into folds for the regression and classification model objects, respectively. tree_frogs_reg_test and tree_frogs_class_test are the analogous testing sets.
reg_res_lr, reg_res_svm, and reg_res_sp contain regression tuning results for a linear regression, support vector machine, and spline model, respectively, fitting latency (i.e. how long the embryos took to hatch in response to the jiggle) in the tree_frogs data, using nearly all of the other variables as predictors. Note that the data underlying these models is filtered to include only embryos that hatched in response to the stimulus.
class_res_rf and class_res_nn contain multiclass classification tuning results for a random forest and neural network classification model, respectively, fitting reflex (a measure of ear function) in the tree_frogs data using nearly all of the other variables as predictors.
log_res_rf and log_res_nn contain binary classification tuning results for a random forest and neural network classification model, respectively, fitting hatched (whether or not the embryos hatched in response to the stimulus) using nearly all of the other variables as predictors.
See ?example_data to learn more about these objects, as well as to browse the source code that generated them.
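For orientation, here is a sketch of how tuning results like reg_res_svm are produced; the workflow object svm_wflow is a hypothetical stand-in, while control_stack_grid() is the stacks control function that retains the information add_candidates() needs.

library(tune)
library(stacks)

# a sketch of generating candidate tuning results; `svm_wflow`
# is a hypothetical workflows object pairing a preprocessor
# with a parsnip model specification
svm_res <-
  tune_grid(
    svm_wflow,
    resamples = reg_folds,          # the rset object described above
    grid = 5,
    control = control_stack_grid()  # save predictions and workflow
  )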
See also
Other core verbs: add_candidates(), fit_members(), stacks()
Examples
# see the "Example Data" section above for
# clarification on the objects used in these examples!
# put together a data stack
reg_st <-
  stacks() %>%
  add_candidates(reg_res_lr) %>%
  add_candidates(reg_res_svm) %>%
  add_candidates(reg_res_sp)
reg_st
#> # A data stack with 3 model definitions and 16 candidate members:
#> # reg_res_lr: 1 model configuration
#> # reg_res_svm: 5 model configurations
#> # reg_res_sp: 10 model configurations
#> # Outcome: latency (numeric)
# evaluate the data stack
reg_st %>%
  blend_predictions()
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# include fewer models by proposing higher penalties
reg_st %>%
  blend_predictions(penalty = c(.5, 1))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.620
#> 2 reg_res_sp_03_1 linear_reg 0.472
#> 3 reg_res_sp_10_1 linear_reg 0.0517
#>
#> Members have not yet been fitted with `fit_members()`.
# allow for negative stacking coefficients
# with the non_negative argument
reg_st %>%
  blend_predictions(non_negative = FALSE)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 12.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 10 highest weighted members are:
#> # A tibble: 10 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_1 svm_rbf -10.5
#> 2 reg_res_sp_04_1 linear_reg -1.38
#> 3 reg_res_sp_05_1 linear_reg 1.35
#> 4 reg_res_svm_1_3 svm_rbf 1.19
#> 5 reg_res_svm_1_2 svm_rbf -0.963
#> 6 reg_res_sp_03_1 linear_reg 0.642
#> 7 reg_res_sp_01_1 linear_reg -0.400
#> 8 reg_res_sp_10_1 linear_reg 0.319
#> 9 reg_res_sp_06_1 linear_reg 0.193
#> 10 reg_res_lr_1_1 linear_reg 0.183
#>
#> Members have not yet been fitted with `fit_members()`.
# use a custom metric in tuning the lasso penalty
library(yardstick)
reg_st %>%
  blend_predictions(metric = metric_set(rmse))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# pass control options for stack blending
reg_st %>%
  blend_predictions(
    control = tune::control_grid(allow_par = TRUE)
  )
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# to speed up the stacking process for preliminary
# results, bump down the `times` argument:
reg_st %>%
  blend_predictions(times = 5)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# the process looks the same with
# multinomial classification models
class_st <-
  stacks() %>%
  add_candidates(class_res_nn) %>%
  add_candidates(class_res_rf) %>%
  blend_predictions()
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
class_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 21 possible candidate members, the ensemble retained 8.
#> Penalty: 0.01.
#> Mixture: 1.
#> Across the 3 classes, there are an average of 4 coefficients per class.
#>
#> The 8 highest weighted member classes are:
#> # A tibble: 8 × 4
#> member type weight class
#> <chr> <chr> <dbl> <fct>
#> 1 .pred_full_class_res_nn_1_1 mlp 23.3 full
#> 2 .pred_mid_class_res_nn_1_1 mlp 1.89 mid
#> 3 .pred_mid_class_res_rf_1_06 rand_forest 1.71 mid
#> 4 .pred_mid_class_res_rf_1_10 rand_forest 1.17 mid
#> 5 .pred_full_class_res_rf_1_03 rand_forest 0.407 full
#> 6 .pred_full_class_res_rf_1_05 rand_forest 0.222 full
#> 7 .pred_full_class_res_rf_1_01 rand_forest 0.00160 full
#> 8 .pred_full_class_res_rf_1_02 rand_forest 0.000322 full
#>
#> Members have not yet been fitted with `fit_members()`.
# ...or binomial classification models
log_st <-
  stacks() %>%
  add_candidates(log_res_nn) %>%
  add_candidates(log_res_rf) %>%
  blend_predictions()
log_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 11 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-04.
#> Mixture: 1.
#>
#> The 3 highest weighted member classes are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 .pred_no_log_res_nn_1_1 mlp 7.41
#> 2 .pred_no_log_res_rf_1_05 rand_forest 3.44
#> 3 .pred_no_log_res_rf_1_02 rand_forest 0.0638
#>
#> Members have not yet been fitted with `fit_members()`.
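# blending results can also be inspected visually; a brief
# sketch using the autoplot() method for model stacks
# (assumes ggplot2 is installed)
library(ggplot2)

# plot the stacking coefficients for each retained member
autoplot(log_st, type = "weights")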