# \donttest{
# The whole example is skipped unless 'party' is installed (presumably needed
# by the "ctree" approach used in the second part).
if (requireNamespace("party", quietly = TRUE)) {
  # requireNamespace() only looks at the FIRST element of its `package`
  # argument, so every optional plotting package must be checked on its own.
  plot_pkgs_available <- all(vapply(
    c("ggplot2", "ggbeeswarm"),
    requireNamespace,
    logical(1),
    quietly = TRUE
  ))

  data("airquality")
  airquality <- airquality[complete.cases(airquality), ]
  x_var <- c("Solar.R", "Wind", "Temp", "Month")
  y_var <- "Ozone"

  # Split the data: the last 50 rows are explained, the rest is used for
  # training
  data_train <- head(airquality, -50)
  data_explain <- tail(airquality, 50)
  x_train <- data_train[, x_var]
  x_explain <- data_explain[, x_var]

  # Fit a linear model
  lm_formula <- as.formula(
    paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))
  )
  model <- lm(lm_formula, data = data_train)

  # phi0 is set to the mean response in the training data
  p <- mean(data_train[[y_var]])

  # Explain predictions with the empirical approach
  x <- explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "empirical",
    phi0 = p,
    n_MC_samples = 1e2
  )

  # NOTE(review): values produced inside `if { }` are not auto-printed; if
  # plot() returns a ggplot object without drawing it, wrap these calls in
  # print() so that every plot is shown.
  if (plot_pkgs_available) {
    # The default plotting option is a bar plot of the Shapley values.
    # We draw bar plots for the first 4 observations
    plot(x, index_x_explain = 1:4)
    # We can also make waterfall plots
    plot(x, plot_type = "waterfall", index_x_explain = 1:4)
    # And only show the 2 features with the largest contribution
    plot(x, plot_type = "waterfall", index_x_explain = 1:4, top_k_features = 2)
    # Or scatter plots showing the distribution of the Shapley values and
    # feature values
    plot(x, plot_type = "scatter")
    # And only for a specific feature
    plot(x, plot_type = "scatter", scatter_features = "Temp")
    # Or a beeswarm plot summarising the Shapley values and feature values
    # for all features
    plot(x, plot_type = "beeswarm")
    # We can change colors
    plot(x, plot_type = "beeswarm", col = c("red", "black"))
    # Additional arguments can be passed to ggbeeswarm::geom_beeswarm() using
    # the '...' argument. For instance, sometimes the beeswarm plots overlap
    # too much. This can be fixed with the 'corral = "wrap"' argument.
    # See ?ggbeeswarm::geom_beeswarm for more information.
    plot(x, plot_type = "beeswarm", corral = "wrap")
  }

  # Example of scatter and beeswarm plots with a factor variable
  # (airquality has no incomplete rows left, so Month_factor has no NAs)
  airquality$Month_factor <- as.factor(month.abb[airquality$Month])
  x_var <- c("Solar.R", "Wind", "Temp", "Month_factor")
  y_var <- "Ozone"

  # Train on all rows and explain the last 50. The explained rows are a subset
  # of the training data, so every Month_factor level they contain is also
  # present in the training data.
  data_train <- airquality
  data_explain <- tail(airquality, 50)
  x_train <- data_train[, x_var]
  x_explain <- data_explain[, x_var]

  # Fit a linear model
  lm_formula <- as.formula(
    paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))
  )
  model <- lm(lm_formula, data = data_train)

  # phi0 is set to the mean response in the training data
  p <- mean(data_train[[y_var]])

  # Explain predictions with the conditional inference tree (ctree) approach
  x <- explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "ctree",
    phi0 = p,
    n_MC_samples = 1e2
  )

  if (plot_pkgs_available) {
    plot(x, plot_type = "scatter")
    plot(x, plot_type = "beeswarm")
  }
}
# }
# Run the code above in your browser using DataLab