Apply a model to create different types of predictions. predict() can be used for all types of models and uses the "type" argument for more specificity.
# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)

# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)

predict_raw(object, ...)
With the exception of type = "raw", the result of predict.model_fit() is a tibble, has as many rows as there are rows in new_data, and has standardized column names, as described below:
For type = "numeric", the tibble has a .pred column for a single outcome and .pred_Yname columns for a multivariate outcome.
For type = "class", the tibble has a .pred_class column.
For type = "prob", the tibble has .pred_classlevel columns.
For type = "conf_int" and type = "pred_int", the tibble has .pred_lower and .pred_upper columns with an attribute for the confidence level. In the case where intervals can be produced for class probabilities (or other non-scalar outputs), the columns are named .pred_lower_classlevel and so on.
For type = "quantile", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .pred and .quantile (and perhaps other columns).
For type = "time", the tibble has a .pred_time column.
For type = "survival", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .eval_time and .pred_survival (and perhaps other columns).
For type = "hazard", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .eval_time and .pred_hazard (and perhaps other columns).
Using type = "raw" with predict.model_fit() will return the unadulterated results of the prediction function.
In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except that 1) no dots appear in names and 2) vectors are never returned by type-specific prediction functions.
When the model fit failed and the error was captured, the predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.
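As a small illustration of the column-naming conventions above, the following sketch fits a logistic regression to mtcars and inspects the names returned for type = "class" and type = "prob". The recoding of am into the levels "automatic" and "manual", and the choice of wt and hp as predictors, are assumptions made for this example only.

```r
library(parsnip)

# Assumption for illustration: recode `am` as a two-level factor outcome
cars <- mtcars
cars$am <- factor(cars$am, levels = c(0, 1), labels = c("automatic", "manual"))

glm_fit <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(am ~ wt + hp, data = cars)

# New data contains only the predictors
new_cars <- cars[1:5, c("wt", "hp")]

class_pred <- predict(glm_fit, new_cars, type = "class")
prob_pred <- predict(glm_fit, new_cars, type = "prob")

# One column of hard class predictions
names(class_pred)

# One probability column per class level
names(prob_pred)

# Same number of rows as new_data
nrow(prob_pred) == nrow(new_cars)
```

With these factor labels, the probability columns follow the .pred_classlevel pattern, with one column for each level of the outcome.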
object: An object of class model_fit.
new_data: A rectangular data object, such as a data frame.
type: A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time", "hazard", "survival", or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.
opts: A list of optional arguments to the underlying predict function that will be used when type = "raw". The list should not include options for the model object or the new data being predicted.
...: Additional parsnip-related options, depending on the value of type. Arguments to the underlying model's prediction function cannot be passed here (use the opts argument instead). Possible arguments are:
interval: for type equal to "survival" or "quantile", should interval estimates be added, if available? Options are "none" and "confidence".
level: for type equal to "conf_int", "pred_int", or "survival", this is the parameter for the tail area of the intervals (e.g. confidence level for confidence intervals). Default value is 0.95.
std_error: for type equal to "conf_int" or "pred_int", add the standard error of fit or prediction (on the scale of the linear predictors). Default value is FALSE.
quantile: for type equal to "quantile", the quantiles of the distribution. Default is (1:9)/10.
eval_time: for type equal to "survival" or "hazard", the time points at which the survival probability or hazard is estimated.
For type = NULL, predict() uses type = "numeric" for regression models, type = "class" for classification, and type = "time" for censored regression.
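A minimal sketch of this default, assuming a linear regression on mtcars with two arbitrarily chosen predictors: leaving type unset should give the same tibble as asking for type = "numeric" explicitly.

```r
library(parsnip)

lm_fit <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ wt + hp, data = mtcars)

new_cars <- mtcars[1:5, c("wt", "hp")]

# type = NULL: predict() picks the type from the model's mode (regression)
default_pred <- predict(lm_fit, new_cars)

# The same request made explicitly
numeric_pred <- predict(lm_fit, new_cars, type = "numeric")

all.equal(default_pred, numeric_pred)
```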
When using type = "conf_int" and type = "pred_int", the options level and std_error can be used. The latter is a logical for an extra column of standard error values (if available).
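The following sketch shows both interval types side by side for an lm-based fit. The model formula and the 90% level are illustrative assumptions; whether a standard error column is available depends on the engine.

```r
library(parsnip)

lm_fit <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ wt + hp, data = mtcars)

new_cars <- mtcars[1:5, c("wt", "hp")]

# Confidence intervals for the mean response, plus a standard error column
conf_90 <- predict(
  lm_fit, new_cars,
  type = "conf_int", level = 0.90, std_error = TRUE
)

# Prediction intervals for individual new observations
pred_90 <- predict(lm_fit, new_cars, type = "pred_int", level = 0.90)

names(conf_90)
names(pred_90)

# Compare interval widths row by row
conf_width <- conf_90$.pred_upper - conf_90$.pred_lower
pred_width <- pred_90$.pred_upper - pred_90$.pred_lower
pred_width > conf_width
```

For a linear model, the prediction interval also accounts for residual variance, so it is wider than the confidence interval at the same level.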
For censored regression, a numeric vector for eval_time is required when survival or hazard probabilities are requested. The time values are required to be unique, finite, non-missing, and non-negative. The predict() functions will adjust the values to fit this specification by removing offending points (with a warning).
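As a hedged sketch of how eval_time is supplied, the example below assumes that the censored extension package is installed (it provides the engines for censored regression) and that the installed parsnip version uses the eval_time argument described here. The lung data from the survival package, the single predictor, and the two time points are arbitrary choices for illustration.

```r
library(parsnip)
library(survival)  # provides Surv() and the lung data
library(censored)  # assumed installed: engines for censored regression

surv_fit <-
  survival_reg() %>%
  set_engine("survival") %>%
  fit(Surv(time, status) ~ age, data = lung)

new_patients <- lung[1:3, ]

# Survival probabilities at two unique, finite, non-negative time points
surv_pred <- predict(
  surv_fit, new_patients,
  type = "survival", eval_time = c(100, 500)
)

# .pred is a list-column: one tibble per row of new_data
names(surv_pred)
surv_pred$.pred[[1]]
```

Each element of the list-column holds one row per requested time point, so it can be expanded with a tool such as tidyr::unnest() if a flat table is preferred.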
Also, when type = "linear_pred", censored regression models will by default be formatted such that the linear predictor increases with time. This may have the opposite sign as what the underlying model's predict() method produces. Set increasing = FALSE to suppress this behavior.
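library(parsnip)
library(dplyr)

# Fit a linear regression on rows 11 to 32 of mtcars
lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))

# Hold out the first 10 rows, without the outcome, as new data
pred_cars <-
  mtcars %>%
  dplyr::slice(1:10) %>%
  dplyr::select(-mpg)

# Default type for a regression model: a tibble with a .pred column
predict(lm_model, pred_cars)

# 90% confidence intervals: .pred_lower and .pred_upper columns
predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)

# Raw output of the underlying predict method; opts is passed through to it
predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)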