xgrove (version 0.1-15)

sgtree: Surrogate trees

Description

Compute surrogate trees of different depths to explain a predictive machine learning model and analyze the trade-off between complexity and explanatory power.

Usage

sgtree(model, data, maxdeps = 1:8, cparam = 0, pfun = NULL, ...)

Value

List of the results:

explanation

Matrix containing tree sizes, rules, explainability \({\Upsilon}\) and the correlation between the predictions of the explanation and the true model.

rules

List of rules for each tree.

model

List of the rpart models.

Arguments

model

A model with corresponding predict function that returns numeric values.

data

Data that must not (!) contain the target variable.

maxdeps

Sequence of integers: Maximum depth of the trees.

cparam

Complexity parameter for growing the trees.

pfun

Optional predict function of the form function(model, data) that returns real-valued predictions. Default is the predict() method of the model.

...

Further arguments to be passed to rpart.control or the predict() method of the model.
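
As an illustration of the pfun argument, the following minimal sketch wraps a model whose default predict() output is not on the desired scale. The logistic regression on mtcars and the object names fit, X and p are assumptions made for this example only; the call to sgtree() is guarded so that the block also runs when xgrove is not installed.

```r
# Hypothetical model to be explained: a logistic regression on mtcars.
fit <- glm(am ~ hp + wt, data = mtcars, family = binomial)
X   <- mtcars[, c("hp", "wt")]          # predictors only, no target variable

# Custom predict function with the documented signature function(model, data).
# It returns probabilities instead of the default link-scale predictions.
pfun <- function(model, data) predict(model, data, type = "response")

p <- pfun(fit, X)
range(p)

# Pass the wrapper to sgtree() if xgrove is available.
if (requireNamespace("xgrove", quietly = TRUE)) {
  st <- xgrove::sgtree(fit, X, maxdeps = 1:3, pfun = pfun)
  st$explanation
}
```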

Details

For each maximum depth in maxdeps, a surrogate tree is trained using rpart on data, with the predictions of the model as target variable. Note that data must not contain the original target variable!
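
The idea described above can be sketched with rpart directly. This is a simplified illustration and not the package's implementation: the linear model on mtcars is a stand-in for the model to be explained, and the quantity labelled upsilon is computed here as the share of the variance of the model's predictions that the tree reproduces, which is assumed to correspond to the explainability measure of the references.

```r
library(rpart)

# Hypothetical stand-in for the model to be explained.
fit <- lm(mpg ~ ., data = mtcars)
X   <- mtcars[, names(mtcars) != "mpg"]   # predictors only, target removed

# The model's predictions become the target of the surrogate trees.
yhat <- predict(fit, X)

# Fit one surrogate tree per maximum depth and record how closely it
# tracks the model's predictions.
maxdeps <- 1:3
res <- sapply(maxdeps, function(d) {
  tree <- rpart(yhat ~ ., data = cbind(X, yhat = yhat),
                control = rpart.control(maxdepth = d, cp = 0))
  p <- predict(tree, X)
  c(depth   = d,
    leaves  = sum(tree$frame$var == "<leaf>"),
    upsilon = 1 - sum((yhat - p)^2) / sum((yhat - mean(yhat))^2),
    cor     = cor(yhat, p))
})
t(res)
```

Deeper trees can only refine the splits of shallower ones, so the explained share does not decrease with depth, while the number of leaves (and hence rules) grows.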

References

  • Szepannek, G. and Laabs, B.H. (2023): Can’t see the forest for the trees -- analyzing groves to explain random forests, Behaviormetrika, submitted.

  • Szepannek, G. and Luebke, K. (2023): How much do we see? On the explainability of partial dependence plots for credit risk scoring, Argumenta Oeconomica 50, DOI: 10.15611/aoe.2023.1.07.

Examples

library(randomForest)
library(pdp)
data(boston)
set.seed(42)
rf    <- randomForest(cmedv ~ ., data = boston)
data  <- boston[,-3] # remove target variable
maxds <- 1:7
st    <- sgtree(rf, data, maxds)
st
# rules for tree of depth 3
st$rules[["3"]]
# plot tree of depth 3
rpart.plot::rpart.plot(st$model[["3"]])