shapr (version 1.0.4)

predict_model: Generate predictions for input data with specified model

Description

Predicts the response for models fitted with stats::lm(), stats::glm(), ranger::ranger(), mgcv::gam(), workflows::workflow() (i.e., tidymodels models), and xgboost::xgb.train(), with a binary or continuous response. See Details for more information.

Usage

predict_model(x, newdata, ...)

# S3 method for default
predict_model(x, newdata, ...)

# S3 method for ar
predict_model(x, newdata, newreg, horizon, ...)

# S3 method for Arima
predict_model(x, newdata, newreg, horizon, explain_idx, explain_lags, y, xreg, ...)

# S3 method for forecast_ARIMA
predict_model(x, newdata, newreg, horizon, ...)

# S3 method for glm
predict_model(x, newdata, ...)

# S3 method for lm
predict_model(x, newdata, ...)

# S3 method for gam
predict_model(x, newdata, ...)

# S3 method for ranger
predict_model(x, newdata, ...)

# S3 method for workflow
predict_model(x, newdata, ...)

# S3 method for xgb.Booster
predict_model(x, newdata, ...)

Value

Numeric. Vector of size equal to the number of rows in newdata.

Arguments

x

Model object for the model to be explained.

newdata

A data.frame/data.table with the features to predict from.

...

newreg and horizon parameters used in models passed to explain_forecast().

horizon

Numeric. The forecast horizon to explain. Passed to the predict_model function.

explain_idx

Numeric vector. The row indices in data and reg denoting points in time to explain.

y

Matrix, data.frame/data.table or a numeric vector. Contains the endogenous variables used to estimate the (conditional) distributions needed to properly estimate the conditional expectations in the Shapley formula including the observations to be explained.

xreg

Matrix, data.frame/data.table or a numeric vector. Contains the exogenous variables used to estimate the (conditional) distributions needed to properly estimate the conditional expectations in the Shapley formula including the observations to be explained. As exogenous variables are used contemporaneously when producing a forecast, this item should contain nrow(y) + horizon rows.
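The row requirement for xreg can be illustrated with a small sketch that uses stats::arima() directly rather than shapr itself. The simulated data and the names n, fit, and fc are assumptions made for illustration only. The model is fitted on the first nrow(y) rows of xreg, and the remaining horizon rows supply the exogenous values needed to produce the forecast.

```r
set.seed(1)

n <- 50        # number of observed time points (length of y)
horizon <- 3   # forecast horizon

# Exogenous variable: n + horizon rows, because the forecast needs
# contemporaneous values for each of the `horizon` future time points
xreg <- matrix(rnorm(n + horizon), ncol = 1, dimnames = list(NULL, "x1"))

# Simulated endogenous series depending on the exogenous variable
y <- 2 * xreg[seq_len(n), 1] + as.numeric(arima.sim(list(ar = 0.5), n = n))

# Fit on the first n rows of xreg only
fit <- arima(y, order = c(1, 0, 0), xreg = xreg[seq_len(n), , drop = FALSE])

# Forecast using the last `horizon` rows of xreg
fc <- predict(
  fit,
  n.ahead = horizon,
  newxreg = xreg[n + seq_len(horizon), , drop = FALSE]
)

length(fc$pred)  # one forecast per step of the horizon
```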

Author

Martin Jullum

Details

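The following models are currently supported: stats::lm(), stats::glm(), ranger::ranger(), mgcv::gam(), workflows::workflow(), and xgboost::xgb.train().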

For a binary classification model, the function always returns the probability prediction for a single class.

If you are explaining a model that is not supported natively, you need to create the predict_model() function yourself and pass it as an argument to explain().

For more details on how to explain such non-supported models (i.e., custom models), see the Advanced usage section of the general usage vignette:
From R: vignette("general_usage", package = "shapr")
Web: https://norskregnesentral.github.io/shapr/articles/general_usage.html#explain-custom-models
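
As a minimal sketch of the custom-model route described above: the class name my_loglm and the function predict_my_loglm() are hypothetical and chosen only for illustration. The sketch assumes that a custom prediction function takes the model object x and the data newdata and returns a numeric vector with one prediction per row, and that explain() accepts it through its predict_model argument together with phi0, as in shapr 1.0.x. The call to explain() is only attempted if shapr is installed.

```r
data("airquality")
airquality <- airquality[complete.cases(airquality), ]

features <- c("Solar.R", "Wind", "Temp", "Month")
train <- head(airquality, -3)
x_train <- train[, features]
y_train <- train$Ozone
x_explain <- tail(airquality, 3)[, features]

# Hypothetical custom model: a linear model for log(Ozone), wrapped in its
# own class so that none of the natively supported methods applies
fit <- lm(log(Ozone) ~ Solar.R + Wind + Temp + Month, data = train)
model <- structure(list(fit = fit), class = "my_loglm")

# Custom prediction function: back-transforms to the original Ozone scale
# and returns a plain numeric vector, one value per row of newdata
predict_my_loglm <- function(x, newdata) {
  as.numeric(exp(predict(x$fit, newdata = as.data.frame(newdata))))
}

preds <- predict_my_loglm(model, x_explain)

# Pass the custom function on to explain() (assumed shapr 1.0.x interface)
if (requireNamespace("shapr", quietly = TRUE)) {
  explanation <- shapr::explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "empirical",
    phi0 = mean(y_train),
    predict_model = predict_my_loglm
  )
}
```

Wrapping the fit in a separate class keeps the example honest: without the custom function there would be no way to obtain predictions on the original response scale for this object.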

Examples

library(shapr)

# Load example data
data("airquality")
airquality <- airquality[complete.cases(airquality), ]
# Split data into training and test data
x_train <- head(airquality, -3)
x_explain <- tail(airquality, 3)
# Fit a linear model
model <- lm(Ozone ~ Solar.R + Wind + Temp + Month, data = x_train)

# Predicting for a model with a standardized format
predict_model(x = model, newdata = x_explain)
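
The Details section states that binary classification models return the probability prediction for a single class. The sketch below illustrates this for a logistic regression. The high_ozone indicator and its threshold of 40 are made up for illustration, and it is an assumption that the glm method corresponds to predict() with type = "response" for binomial models. The base R call shows the expected kind of output, and predict_model() is only called if shapr is installed.

```r
data("airquality")
airquality <- airquality[complete.cases(airquality), ]

# Hypothetical binary response derived from Ozone, for illustration only
airquality$high_ozone <- as.integer(airquality$Ozone > 40)

x_train <- head(airquality, -3)
x_explain <- tail(airquality, 3)

# Fit a logistic regression model
model_bin <- glm(
  high_ozone ~ Solar.R + Wind + Temp + Month,
  family = binomial,
  data = x_train
)

# Probabilities on the response scale, one per row of x_explain
p <- as.vector(predict(model_bin, newdata = x_explain, type = "response"))

# The standardized interface, if shapr is available
if (requireNamespace("shapr", quietly = TRUE)) {
  p_shapr <- shapr::predict_model(model_bin, newdata = x_explain)
}
```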
