
Apply a model stack to create different types of predictions.

Usage

# S3 method for class 'model_stack'
predict(object, new_data, type = NULL, members = FALSE, opts = list(), ...)

Arguments

object

A model stack with fitted members, as output by fit_members().

new_data

A rectangular data object, such as a data frame.

type

Format of the returned predicted values: one of "numeric", "class", or "prob". When NULL, predict() chooses an appropriate value based on the model's mode.

members

Logical. Whether or not to additionally return the predictions for each of the ensemble members.

opts

A list of optional arguments to the underlying predict function passed on to parsnip::predict.model_fit for each member.

...

Additional arguments. Currently ignored.
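How a NULL type gets resolved can be sketched with a small helper. resolve_pred_type() is hypothetical, not the stacks package's internal code. It assumes that regression stacks default to "numeric" and classification stacks default to "class". That assumption matches the Examples below, where predict() with type = NULL returns .pred for the regression stack and .pred_class for the classification stack.

```r
# Hypothetical helper (not stacks internals): resolve a NULL `type`
# from the model stack's mode, then validate it against the allowed values.
resolve_pred_type <- function(mode, type = NULL) {
  valid <- c("numeric", "class", "prob")
  if (is.null(type)) {
    # assumed defaults: numeric for regression, hard classes otherwise
    type <- if (mode == "regression") "numeric" else "class"
  }
  # errors if `type` is not one of the valid values
  match.arg(type, valid)
}

resolve_pred_type("regression")
resolve_pred_type("classification")
resolve_pred_type("classification", "prob")
```

An explicitly supplied type always takes precedence over the mode-based default.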

Example Data

This package provides some resampling objects and datasets, derived from a study on 1212 red-eyed tree frog embryos, for use in examples and vignettes!

Red-eyed tree frog (RETF) embryos can hatch earlier than their normal 7ish days if they detect a potential predator threat. Researchers wanted to determine how, and when, these tree frog embryos were able to detect stimuli from their environment. To do so, they subjected the embryos at varying developmental stages to "predator stimulus" by jiggling the embryos with a blunt probe. Beforehand, though, some of the embryos were treated with gentamicin, a compound that knocks out their lateral line (a sensory organ). Researcher Julie Jung and her crew found that these factors inform whether an embryo hatches prematurely or not!

Note that the data included with the stacks package is not necessarily a representative or unbiased subset of the complete dataset, and is only for demonstrative purposes.

reg_folds and class_folds are rset cross-fold validation objects from rsample, splitting the training data for the regression and classification model objects, respectively. tree_frogs_reg_test and tree_frogs_class_test are the analogous testing sets.

reg_res_lr, reg_res_svm, and reg_res_sp contain regression tuning results for a linear regression, support vector machine, and spline model, respectively, fitting latency (i.e., how long the embryos took to hatch in response to the jiggle) in the tree_frogs data, using most of the other variables as predictors. Note that the data underlying these models is filtered to include data only from embryos that hatched in response to the stimulus.

class_res_rf and class_res_nn contain multiclass classification tuning results for a random forest and neural network classification model, respectively, fitting reflex (a measure of ear function) in the tree_frogs data, using most of the other variables as predictors.

log_res_rf and log_res_nn contain binary classification tuning results for a random forest and neural network classification model, respectively, fitting hatched (whether or not the embryos hatched in response to the stimulus) using most of the other variables as predictors.

See ?example_data to learn more about these objects, as well as browse the source code that generated them.

Examples


# see the "Example Data" section above for
# clarification on the data and tuning results
# objects used in these examples!

library(stacks)

data(tree_frogs_reg_test)
data(tree_frogs_class_test)

# build and fit a regression model stack
reg_st <-
  stacks() %>%
  add_candidates(reg_res_lr) %>%
  add_candidates(reg_res_sp) %>%
  blend_predictions() %>%
  fit_members()

reg_st
#> ── A stacked ensemble model ─────────────────────────────────────
#> 
#> Out of 11 possible candidate members, the ensemble retained 4.
#> Penalty: 1e-06.
#> Mixture: 1.
#> 
#> The 4 highest weighted members are:
#> # A tibble: 4 × 3
#>   member          type       weight
#>   <chr>           <chr>       <dbl>
#> 1 reg_res_sp_03_1 linear_reg 0.485 
#> 2 reg_res_sp_10_1 linear_reg 0.247 
#> 3 reg_res_lr_1_1  linear_reg 0.129 
#> 4 reg_res_sp_05_1 linear_reg 0.0666

# predict on the tree frogs testing data
predict(reg_st, tree_frogs_reg_test)
#> # A tibble: 143 × 1
#>    .pred
#>    <dbl>
#>  1 119. 
#>  2  81.4
#>  3 102. 
#>  4  35.5
#>  5 119. 
#>  6  50.5
#>  7 122. 
#>  8  82.7
#>  9  50.2
#> 10  75.7
#> # ℹ 133 more rows

# include the predictions from the members
predict(reg_st, tree_frogs_reg_test, members = TRUE)
#> # A tibble: 143 × 5
#>    .pred reg_res_lr_1_1 reg_res_sp_10_1 reg_res_sp_05_1 reg_res_sp_03_1
#>    <dbl>          <dbl>           <dbl>           <dbl>           <dbl>
#>  1 119.           138.            125.            121.            114. 
#>  2  81.4           82.4            84.8            81.8            77.1
#>  3 102.           116.            111.            112.             93.3
#>  4  35.5           35.8            29.7            32.5            29.6
#>  5 119.           111.            115.            115.            127. 
#>  6  50.5           38.8            37.4            36.2            55.3
#>  7 122.           123.            103.            104.            137. 
#>  8  82.7           82.3            78.6            82.0            82.8
#>  9  50.2           38.7            37.3            36.2            54.8
#> 10  75.7           78.8            75.3            76.9            71.8
#> # ℹ 133 more rows
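
# A toy sketch, not the stacks internals: assuming the regression
# meta-learner is a penalized linear model (the printed Penalty and
# Mixture suggest an elastic-net style fit), each ensemble prediction is
# an intercept plus the weighted sum of the member predictions.
# blend_numeric() and all numbers below are hypothetical.
blend_numeric <- function(member_preds, weights, intercept = 0) {
  # every weighted member must have a prediction column
  stopifnot(all(names(weights) %in% names(member_preds)))
  # member predictions as a matrix, columns ordered like `weights`
  x <- as.matrix(member_preds[names(weights)])
  data.frame(.pred = intercept + drop(x %*% weights))
}

toy_members <- data.frame(member_a = c(10, 20), member_b = c(12, 18))
toy_weights <- c(member_a = 0.5, member_b = 0.25)
blend_numeric(toy_members, toy_weights, intercept = 1)
# .pred is 9 (1 + 0.5 * 10 + 0.25 * 12) and 15.5 (1 + 0.5 * 20 + 0.25 * 18)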

# build and fit a classification model stack
class_st <-
  stacks() %>%
  add_candidates(class_res_nn) %>%
  add_candidates(class_res_rf) %>%
  blend_predictions() %>%
  fit_members()
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
 
class_st
#> ── A stacked ensemble model ─────────────────────────────────────
#> 
#> Out of 21 possible candidate members, the ensemble retained 8.
#> Penalty: 0.01.
#> Mixture: 1.
#> Across the 3 classes, there are an average of 4 coefficients per class.
#> 
#> The 8 highest weighted member classes are:
#> # A tibble: 8 × 4
#>   member                       type           weight class
#>   <chr>                        <chr>           <dbl> <fct>
#> 1 .pred_full_class_res_nn_1_1  mlp         23.3      full 
#> 2 .pred_mid_class_res_nn_1_1   mlp          1.89     mid  
#> 3 .pred_mid_class_res_rf_1_06  rand_forest  1.71     mid  
#> 4 .pred_mid_class_res_rf_1_10  rand_forest  1.17     mid  
#> 5 .pred_full_class_res_rf_1_03 rand_forest  0.407    full 
#> 6 .pred_full_class_res_rf_1_05 rand_forest  0.222    full 
#> 7 .pred_full_class_res_rf_1_01 rand_forest  0.00160  full 
#> 8 .pred_full_class_res_rf_1_02 rand_forest  0.000322 full 
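
# A toy sketch, not the stacks internals: assuming the multiclass
# meta-learner is a multinomial model, each class gets its own linear
# predictor built from the member probabilities, and a softmax turns
# those predictors into class probabilities. softmax_rows() and the
# numbers below are hypothetical.
softmax_rows <- function(eta) {
  # subtract each row's maximum before exponentiating, for numerical stability
  eta <- eta - apply(eta, 1, max)
  e <- exp(eta)
  # normalize so that each row sums to 1
  e / rowSums(e)
}

toy_eta <- matrix(
  c(2, 0, 0,
    0, 1, 1),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("full", "low", "mid"))
)
round(softmax_rows(toy_eta), 3)
# each row sums to 1; row 1 favors "full", row 2 ties "low" and "mid"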

# predict reflex, first as a class, then as
# class probabilities
predict(class_st, tree_frogs_class_test)
#> # A tibble: 303 × 1
#>    .pred_class
#>    <fct>      
#>  1 full       
#>  2 mid        
#>  3 mid        
#>  4 mid        
#>  5 full       
#>  6 full       
#>  7 full       
#>  8 full       
#>  9 full       
#> 10 full       
#> # ℹ 293 more rows
predict(class_st, tree_frogs_class_test, type = "prob")
#> # A tibble: 303 × 3
#>    .pred_full .pred_low .pred_mid
#>         <dbl>     <dbl>     <dbl>
#>  1    0.991     0.00777   0.00132
#>  2    0.00877   0.437     0.554  
#>  3    0.00449   0.260     0.736  
#>  4    0.00912   0.417     0.574  
#>  5    0.990     0.00831   0.00141
#>  6    0.991     0.00781   0.00132
#>  7    0.991     0.00773   0.00131
#>  8    0.991     0.00780   0.00132
#>  9    0.991     0.00777   0.00132
#> 10    0.991     0.00777   0.00132
#> # ℹ 293 more rows
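
# A sketch based on an assumption about the rule, not the stacks internals:
# the hard class can be recovered from the probabilities by taking the
# most probable class in each row. The values are the first two rows
# printed above; prob_to_class() is a hypothetical helper.
printed_probs <- data.frame(
  .pred_full = c(0.991, 0.00877),
  .pred_low  = c(0.00777, 0.437),
  .pred_mid  = c(0.00132, 0.554)
)

prob_to_class <- function(probs) {
  # strip the ".pred_" prefix to get the class labels
  lvls <- sub("^\\.pred_", "", names(probs))
  # column index of the largest probability in each row
  idx <- max.col(as.matrix(probs), ties.method = "first")
  data.frame(.pred_class = factor(lvls[idx], levels = lvls))
}

prob_to_class(printed_probs)
# gives "full" and "mid", agreeing with rows 1 and 2 of the class predictions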

# returning the member predictions as well
predict(
  class_st, 
  tree_frogs_class_test, 
  type = "prob", 
  members = TRUE
)
#> # A tibble: 303 × 24
#>    .pred_full .pred_low .pred_mid .pred_low_class_res_nn_1_1
#>         <dbl>     <dbl>     <dbl>                      <dbl>
#>  1    0.991     0.00777   0.00132                      0.212
#>  2    0.00877   0.437     0.554                        0.481
#>  3    0.00449   0.260     0.736                        0.245
#>  4    0.00912   0.417     0.574                        0.439
#>  5    0.990     0.00831   0.00141                      0.212
#>  6    0.991     0.00781   0.00132                      0.212
#>  7    0.991     0.00773   0.00131                      0.212
#>  8    0.991     0.00780   0.00132                      0.212
#>  9    0.991     0.00777   0.00132                      0.212
#> 10    0.991     0.00777   0.00132                      0.212
#> # ℹ 293 more rows
#> # ℹ 20 more variables: .pred_low_class_res_rf_1_06 <dbl>,
#> #   .pred_low_class_res_rf_1_10 <dbl>, .pred_low_class_res_rf_1_03 <dbl>,
#> #   .pred_low_class_res_rf_1_02 <dbl>, .pred_low_class_res_rf_1_05 <dbl>,
#> #   .pred_low_class_res_rf_1_01 <dbl>, .pred_mid_class_res_nn_1_1 <dbl>,
#> #   .pred_mid_class_res_rf_1_06 <dbl>, .pred_mid_class_res_rf_1_10 <dbl>,
#> #   .pred_mid_class_res_rf_1_03 <dbl>, …
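
With members = TRUE and type = "prob", the member columns in the output above appear to follow a .pred_<class>_<member> naming pattern. A minimal sketch, assuming that pattern holds: member_cols() is a hypothetical helper that selects one member's columns by name suffix, and the column names are a subset copied from the output above.

```r
# Hypothetical helper: select the prediction columns belonging to one member,
# assuming member columns are named ".pred_<class>_<member>".
member_cols <- function(col_names, member) {
  col_names[endsWith(col_names, paste0("_", member))]
}

# a subset of the column names printed above
cols <- c(
  ".pred_full", ".pred_low", ".pred_mid",
  ".pred_low_class_res_nn_1_1", ".pred_low_class_res_rf_1_06",
  ".pred_mid_class_res_nn_1_1", ".pred_mid_class_res_rf_1_06"
)
member_cols(cols, "class_res_nn_1_1")
# returns the .pred_low and .pred_mid columns for class_res_nn_1_1
```

On a real prediction tibble, its names() could be passed as col_names; matching on the full "_<member>" suffix keeps, for example, a member ending in _1_1 from also matching one ending in _1_10.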