# \donttest{
if (requireNamespace("xgboost", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  requireNamespace("ggplot2", quietly = TRUE) &&
  torch::torch_is_installed()) {
  # Prepare the data: keep only the complete rows of airquality, explain the
  # first six observations and use the remaining ones for training.
  data("airquality")
  data <- data.table::as.data.table(airquality)
  data <- data[complete.cases(data), ]
  x_var <- c("Solar.R", "Wind", "Temp", "Month")
  y_var <- "Ozone"
  ind_x_explain <- 1:6
  x_train <- data[-ind_x_explain, ..x_var]
  y_train <- data[-ind_x_explain, get(y_var)]
  x_explain <- data[ind_x_explain, ..x_var]

  # Fitting a basic xgboost model to the training data.
  # `nrounds` is spelled out rather than relying on partial argument matching.
  model <- xgboost::xgboost(
    data = as.matrix(x_train),
    label = y_train,
    nrounds = 100,
    verbose = FALSE
  )

  # Specifying phi_0, i.e., the expected prediction without any features
  p0 <- mean(y_train)

  # Train vaeac with and without paired sampling. The two calls are identical
  # except for `vaeac.paired_sampling`.
  explanation_paired <- explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "vaeac",
    phi0 = p0,
    n_MC_samples = 1, # As we are only interested in the training of the vaeac
    vaeac.epochs = 10, # Should be higher in applications.
    vaeac.n_vaeacs_initialize = 1,
    vaeac.width = 16,
    vaeac.depth = 2,
    vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE)
  )
  explanation_regular <- explain(
    model = model,
    x_explain = x_explain,
    x_train = x_train,
    approach = "vaeac",
    phi0 = p0,
    n_MC_samples = 1, # As we are only interested in the training of the vaeac
    vaeac.epochs = 10, # Should be higher in applications.
    vaeac.n_vaeacs_initialize = 1,
    vaeac.width = 16,
    vaeac.depth = 2,
    vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE)
  )

  # Collect the explanation objects in a named list
  explanation_list <- list(
    "Regular sampling" = explanation_regular,
    "Paired sampling" = explanation_paired
  )

  # Inside an `if` block R does not auto-print intermediate values, so every
  # plot is printed explicitly; otherwise only the last one would be shown.

  # Call the function with the named list; it will use the provided names
  print(plot_vaeac_eval_crit(explanation_list = explanation_list))

  # The function also works if we have only one method,
  # but then one should only look at the method plot.
  print(plot_vaeac_eval_crit(
    explanation_list = explanation_list[2],
    plot_type = "method"
  ))

  # Can alter the plot
  print(plot_vaeac_eval_crit(
    explanation_list = explanation_list,
    plot_from_nth_epoch = 2,
    plot_every_nth_epoch = 2,
    facet_wrap_scales = "free"
  ))

  # If we only want the VLB
  print(plot_vaeac_eval_crit(
    explanation_list = explanation_list,
    criteria = "VLB",
    plot_type = "criterion"
  ))

  # If we only want the criterion version
  tmp_fig_criterion <-
    plot_vaeac_eval_crit(explanation_list = explanation_list, plot_type = "criterion")

  # Since tmp_fig_criterion is a ggplot2 object, we can alter it
  # by, e.g., adding points or smooths with se bands
  print(
    tmp_fig_criterion +
      ggplot2::geom_point(shape = "circle", size = 1, ggplot2::aes(col = Method))
  )

  # Drop the first layer (presumably the line layer added by
  # plot_vaeac_eval_crit) so that only the smoothed curves remain.
  tmp_fig_criterion$layers[[1]] <- NULL
  print(
    tmp_fig_criterion +
      ggplot2::geom_smooth(method = "loess", formula = y ~ x, se = TRUE) +
      ggplot2::scale_color_brewer(palette = "Set1") +
      ggplot2::theme_minimal()
  )
}
# }
# Run the code above in your browser using DataLab