Add predictions to a data frame

add_predictions(data, model, var = "pred", type = NULL)

spread_predictions(data, ..., type = NULL)

gather_predictions(data, ..., .pred = "pred", .model = "model", type = NULL)

Arguments

data

A data frame used to generate the predictions.

model

A fitted model with a predict() method. add_predictions takes a single model.

var

The name of the output column. Defaults to "pred".

type

Prediction type, passed on to stats::predict(). Consult the predict() method documentation for the model's class to determine valid values.

...

gather_predictions and spread_predictions take multiple models. Each model's name is taken from its argument name if one is given, otherwise from the name of the model itself (for example, m1).

.pred, .model

The names of the prediction and model columns created by gather_predictions.
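To illustrate the naming rule for `...`, here is a small base-R sketch. `model_names()` is a hypothetical helper written for this illustration, not modelr's implementation; it captures the unevaluated arguments with `substitute()` and falls back to the deparsed expression when an argument is unnamed.

```r
# Hypothetical helper (not part of modelr): derive one label per model
# passed through `...`, preferring the argument name when one is supplied.
model_names <- function(...) {
  # Capture the unevaluated arguments; drop the leading `list` symbol.
  exprs <- as.list(substitute(list(...)))[-1]
  nms <- names(exprs)
  if (is.null(nms)) nms <- rep("", length(exprs))
  # Unnamed arguments fall back to their deparsed expression, e.g. "m1".
  deparsed <- vapply(exprs, function(e) paste(deparse(e), collapse = ""),
                     character(1))
  ifelse(nms == "", deparsed, nms)
}

# The models are never evaluated, so m1 and m2 need not exist here.
model_names(m1, m2)         # "m1" "m2"
model_names(m1, quad = m2)  # "m1" "quad"
```

Under this rule, spread_predictions(grid, m1, quad = m2) would be expected to produce columns named m1 and quad.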

Value

A data frame. add_predictions adds a single new column, named by var (default pred), to the input data. spread_predictions adds one column for each model. gather_predictions adds two columns, named by .model and .pred (default model and pred), and repeats the input rows once for each model.
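The three output shapes can be sketched in base R. The helpers below (`predict_with_type()`, `add_pred()`, `spread_pred()`, `gather_pred()`) are hypothetical stand-ins written for illustration, not modelr's code; they assume the models are passed as a named list and that `type` should be forwarded to predict() only when it is supplied.

```r
# Hypothetical base-R stand-ins for the three functions (not modelr's code).
# `type` is forwarded to predict() only when supplied, so each model's own
# predict() default applies otherwise.
predict_with_type <- function(model, data, type = NULL) {
  if (is.null(type)) {
    p <- predict(model, newdata = data)
  } else {
    p <- predict(model, newdata = data, type = type)
  }
  unname(p)
}

# One new column, named by `var`.
add_pred <- function(data, model, var = "pred", type = NULL) {
  data[[var]] <- predict_with_type(model, data, type)
  data
}

# One new column per model; column names come from the list names.
spread_pred <- function(data, models, type = NULL) {
  for (nm in names(models)) {
    data[[nm]] <- predict_with_type(models[[nm]], data, type)
  }
  data
}

# Two new columns (model name and prediction); input rows repeated per model.
gather_pred <- function(data, models, .pred = "pred", .model = "model",
                        type = NULL) {
  pieces <- lapply(names(models), function(nm) {
    out <- data.frame(rep(nm, nrow(data)), data, stringsAsFactors = FALSE)
    names(out)[1] <- .model
    out[[.pred]] <- predict_with_type(models[[nm]], data, type)
    out
  })
  out <- do.call(rbind, pieces)
  rownames(out) <- NULL
  out
}

# Simulated data in the spirit of the Examples section.
set.seed(1)
df <- data.frame(x = sort(runif(100)))
df$y <- 5 * df$x + 0.5 * df$x^2 + 3 + rnorm(100)
m1 <- lm(y ~ x, data = df)
m2 <- lm(y ~ poly(x, 2), data = df)
grid <- data.frame(x = seq(0, 1, length.out = 10))

names(add_pred(grid, m1))                         # "x" "pred"
names(spread_pred(grid, list(m1 = m1, m2 = m2)))  # "x" "m1" "m2"
g <- gather_pred(grid, list(m1 = m1, m2 = m2))
names(g)                                          # "model" "x" "pred"
nrow(g)                                           # 20
```

The long (gathered) form is convenient for plotting all models with one grouping variable, while the wide (spread) form makes side-by-side comparison of models easier.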

Examples

library(modelr)
library(magrittr)  # for %>%

df <- tibble::tibble(
  x = sort(runif(100)),
  y = 5 * x + 0.5 * x ^ 2 + 3 + rnorm(length(x))
)
plot(df)


m1 <- lm(y ~ x, data = df)
grid <- data.frame(x = seq(0, 1, length.out = 10))
grid %>% add_predictions(m1)
#>            x     pred
#> 1  0.0000000 3.036268
#> 2  0.1111111 3.649103
#> 3  0.2222222 4.261939
#> 4  0.3333333 4.874774
#> 5  0.4444444 5.487609
#> 6  0.5555556 6.100445
#> 7  0.6666667 6.713280
#> 8  0.7777778 7.326116
#> 9  0.8888889 7.938951
#> 10 1.0000000 8.551786

m2 <- lm(y ~ poly(x, 2), data = df)
grid %>% spread_predictions(m1, m2)
#>            x       m1       m2
#> 1  0.0000000 3.036268 3.211703
#> 2  0.1111111 3.649103 3.721594
#> 3  0.2222222 4.261939 4.257809
#> 4  0.3333333 4.874774 4.820350
#> 5  0.4444444 5.487609 5.409215
#> 6  0.5555556 6.100445 6.024405
#> 7  0.6666667 6.713280 6.665920
#> 8  0.7777778 7.326116 7.333759
#> 9  0.8888889 7.938951 8.027924
#> 10 1.0000000 8.551786 8.748413
grid %>% gather_predictions(m1, m2)
#>    model         x     pred
#> 1     m1 0.0000000 3.036268
#> 2     m1 0.1111111 3.649103
#> 3     m1 0.2222222 4.261939
#> 4     m1 0.3333333 4.874774
#> 5     m1 0.4444444 5.487609
#> 6     m1 0.5555556 6.100445
#> 7     m1 0.6666667 6.713280
#> 8     m1 0.7777778 7.326116
#> 9     m1 0.8888889 7.938951
#> 10    m1 1.0000000 8.551786
#> 11    m2 0.0000000 3.211703
#> 12    m2 0.1111111 3.721594
#> 13    m2 0.2222222 4.257809
#> 14    m2 0.3333333 4.820350
#> 15    m2 0.4444444 5.409215
#> 16    m2 0.5555556 6.024405
#> 17    m2 0.6666667 6.665920
#> 18    m2 0.7777778 7.333759
#> 19    m2 0.8888889 8.027924
#> 20    m2 1.0000000 8.748413
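The linear models above work with predict()'s default type. For models where `type` changes the scale of the result, such as a logistic glm(), the value passed through `type` matters. The base-R sketch below uses the built-in `mtcars` data, an illustrative choice that is not part of the original examples, to compare the two common glm types.

```r
# Logistic regression on the built-in mtcars data (illustrative only).
fit <- glm(am ~ wt, data = mtcars, family = binomial)
new <- data.frame(wt = c(2, 3, 4))

# "link" (the glm default) returns the linear predictor (log-odds);
# "response" returns fitted probabilities.
link <- predict(fit, newdata = new, type = "link")
resp <- predict(fit, newdata = new, type = "response")

# With the default logit link, the two are related by the inverse logit.
all.equal(plogis(link), resp)
```

With modelr, the same choice would be expressed as add_predictions(new, fit, type = "response"), which forwards type to stats::predict().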