
iml (version 0.1)

tree.surrogate: Decision tree surrogate model

Description

tree.surrogate() fits a decision tree on the predictions of a machine learning model to make it interpretable.

Usage

tree.surrogate(object, X, sample.size = 100, class = NULL, maxdepth = 2,
  tree.args = NULL, ...)

Arguments

object

The machine learning model. Several types are allowed; mlr WrappedModel and caret train objects are recommended. The object can also be a function that predicts the outcome from the features, or anything with an S3 predict method, such as an object of class lm.

X

data.frame with the data for the prediction model

sample.size

The number of instances to be sampled from X.

class

For classification, class specifies the class for which to predict the probability. By default (NULL), multiclass classification is performed.

maxdepth

The maximum depth of the tree. Default is 2.

tree.args

A list with further arguments for ctree

...

Further arguments for the prediction method.
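
As a rough illustration of the object, X and sample.size arguments, the sketch below uses base R only. It does not call iml: the wrapper name predict_fun, the use of mtcars, and the choice to sample rows with replacement are all assumptions made for the example, not documented behaviour of tree.surrogate().

```r
# Base R sketch of the inputs; NOT iml's internal code.
mod <- lm(mpg ~ wt + hp, data = mtcars)

# X holds only the features, without the outcome column
X <- mtcars[, c("wt", "hp")]

# Hypothetical wrapper: a plain function that maps features to predictions
predict_fun <- function(newdata) as.numeric(predict(mod, newdata = newdata))

# Draw sample.size rows from X (assumption: sampling with replacement)
set.seed(1)
sample.size <- 10
X.sample <- X[sample(nrow(X), size = sample.size, replace = TRUE), , drop = FALSE]

# Predictions of the model on the sampled instances
y.hat <- predict_fun(X.sample)
```

The surrogate tree is then trained on pairs like (X.sample, y.hat) rather than on the original outcome.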

Value

A TreeSurrogate object (R6). Its methods and variables can be accessed with the $-operator:

tree

the fitted tree of class party. See also ctree.

maxdepth

the maximal tree depth set by the user.

data()

method to extract the results of the tree. Returns the sampled features X together with the leaf node information (columns ..node and ..path) and the predicted \(\hat{y}\) for the tree and the machine learning model (columns starting with ..y.hat).

plot()

method to plot the leaf nodes of the surrogate decision tree. See also plot.TreeSurrogate.

predict()

method to predict new data with the tree. See also predict.TreeSurrogate.

run()

[internal] method to run the interpretability method. Use obj$run(force = TRUE) to force a rerun.

General R6 methods

clone()

[internal] method to clone the R6 object.

initialize()

[internal] method to initialize the R6 object.
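
To give a feel for the kind of table that data() describes (features, leaf node, and the two predictions), here is a minimal sketch built directly with partykit instead of iml. The column names node, y.hat.model and y.hat.tree are illustrative stand-ins, not the ..node and ..y.hat names that the TreeSurrogate object produces, and the lm model on iris is only a placeholder for "the machine learning model".

```r
library(partykit)

# Placeholder model: any model with a predict method would do
mod <- lm(Petal.Width ~ Sepal.Length + Sepal.Width + Petal.Length, data = iris)
X <- iris[, c("Sepal.Length", "Sepal.Width", "Petal.Length")]

# Fit a shallow conditional inference tree on the model's predictions
fit.data <- data.frame(X, y.hat = as.numeric(predict(mod, newdata = X)))
surrogate <- ctree(y.hat ~ ., data = fit.data,
                   control = ctree_control(maxdepth = 2))

# Assemble a results table: features, leaf node id, both predictions
results <- data.frame(
  X,
  node        = as.integer(predict(surrogate, newdata = fit.data, type = "node")),
  y.hat.model = fit.data$y.hat,
  y.hat.tree  = as.numeric(predict(surrogate, newdata = fit.data))
)
head(results)
```

Grouping such a table by leaf node shows how widely the model's predictions spread inside each leaf.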

Details

A conditional inference tree is fitted on the predicted \(\hat{y}\) from the machine learning model and the data \(X\). The ctree function from the partykit package is used to fit the tree. By default, a tree with a maximum depth of 2 is fitted to keep it interpretable.
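
The idea in this section can be sketched by calling partykit directly. This is a minimal illustration under stated assumptions, not the code that tree.surrogate() runs: the lm model on mtcars stands in for the machine learning model, and the R-squared at the end is an extra check added for this example to show how closely the tree tracks the model.

```r
library(partykit)

# Stand-in for the machine learning model
mod <- lm(mpg ~ wt + hp + disp, data = mtcars)
X <- mtcars[, c("wt", "hp", "disp")]

# The tree's target is the model's prediction, not the observed outcome
surrogate.data <- X
surrogate.data$y.hat <- as.numeric(predict(mod, newdata = X))

# Conditional inference tree, limited to depth 2
surrogate <- ctree(y.hat ~ ., data = surrogate.data,
                   control = ctree_control(maxdepth = 2))

# Illustrative check: share of the variance in y.hat that the tree reproduces
tree.pred <- as.numeric(predict(surrogate, newdata = surrogate.data))
sse <- sum((surrogate.data$y.hat - tree.pred)^2)
sst <- sum((surrogate.data$y.hat - mean(surrogate.data$y.hat))^2)
r.squared <- 1 - sse / sst
r.squared
```

A value close to 1 suggests the shallow tree is a faithful summary of the model on this data; a low value means the tree's explanation should be read with caution.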

See Also

predict.TreeSurrogate, plot.TreeSurrogate

For the tree implementation, see ctree.

Examples

# NOT RUN {
# Fit a Random Forest on the Boston housing data set
library("randomForest")
data("Boston", package  = "MASS")
mod = randomForest(medv ~ ., data = Boston, ntree = 50)

# Fit a decision tree as a surrogate for the whole random forest
dt = tree.surrogate(mod, Boston[-which(names(Boston) == 'medv')], 200)

# Plot the resulting leaf nodes
plot(dt) 

# Use the tree to predict new data
predict(dt, Boston[1:10,])

# Extract the results
dat = dt$data()
head(dat)


# It also works for classification
mod = randomForest(Species ~ ., data = iris, ntree = 50)

# Fit a decision tree as a surrogate for the whole random forest
X = iris[-which(names(iris) == 'Species')]
dt = tree.surrogate(mod, X, 200, predict.args = list(type = 'prob'), maxdepth=2, class=3)

# Plot the resulting leaf nodes
plot(dt) 

# If you want to visualise the tree directly:
plot(dt$tree)

# Use the tree to predict new data
set.seed(42)
iris.sample = X[sample(1:nrow(X), 10),]
predict(dt, iris.sample)
predict(dt, iris.sample, type = 'class')

# Extract the dataset
dat = dt$data()
head(dat)
# }
