
iml (version 0.1)

feature.imp: Feature importance

Description

feature.imp() computes feature importances for machine learning models. The importance of a feature is the factor by which the model's prediction error increases when the feature is shuffled.
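The ratio described above can be illustrated with a minimal base R sketch. It does not use iml and is not the package's implementation; the model, the data set (mtcars) and the helper name mse are assumptions chosen only for illustration.

```r
# Minimal sketch of "importance = permuted error / original error".
# Hypothetical example: an lm fit on mtcars, not the iml internals.
set.seed(1)
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Loss function: actual values first, predictions second
mse <- function(actual, predicted) mean((actual - predicted)^2)

# Error of the model on the unchanged data
error_original <- mse(mtcars$mpg, predict(fit, mtcars))

# Shuffle one feature to break its association with the target
shuffled <- mtcars
shuffled$wt <- sample(shuffled$wt)
error_permuted <- mse(mtcars$mpg, predict(fit, shuffled))

# Importance of wt: the factor by which the error changes after shuffling
importance_wt <- error_permuted / error_original
importance_wt
```

A value well above 1 suggests the model relies on the feature; a value near 1 suggests shuffling it makes little difference.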

Usage

feature.imp(object, X, y, class = NULL, loss, method = "shuffle", ...)

Arguments

object

The machine learning model. Several types are supported; mlr WrappedModel and caret train objects are recommended. The object can also be a function that predicts the outcome from the features, or any object with an S3 predict method, such as an object of class lm.

X

A data.frame with the feature data passed to the prediction model.

y

The vector or data.frame with the actual target values associated with X.

class

For classification, class specifies the class for which to predict the probability. By default, all classes are used (multiclass classification).

loss

The loss function. Either a string (e.g. "ce" for classification or "mse" for regression) or a function. See Details.

method

Either 'shuffle' or 'cartesian'. See Details.

...

Further arguments for the prediction method.

Value

An Importance object (R6). Its methods and variables can be accessed with the $-operator:

error.original

The loss of the model before perturbing features.

loss

The loss function. Can also be applied to data: object$loss(actual, predicted)

data()

method to extract the results of the feature importance computation. Returns a data.frame with importance and permutation error measurements per feature.

plot()

method to plot the feature importances. See plot.Importance

run()

[internal] method to run the interpretability method. Use obj$run(force = TRUE) to force a rerun.

General R6 methods

clone()

[internal] method to clone the R6 object.

initialize()

[internal] method to initialize the R6 object.

Details

Read the Interpretable Machine Learning book to learn more about feature importance: https://christophm.github.io/interpretable-ml-book/permutation-feature-importance.html

Two permutation schemes are implemented:

  • shuffle: A simple shuffling of the feature values, yielding n perturbed instances per feature (faster)

  • cartesian: Matching every instance with the feature value of all other instances, yielding n x (n-1) perturbed instances per feature (slow)
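As a rough illustration of the two schemes, the sketch below builds the perturbed data for a single feature of a toy data.frame in base R. The data and all object names are hypothetical, and this is not how iml constructs the data internally; it only shows where the counts n and n x (n-1) come from.

```r
# Toy data: n = 5 instances, feature 'a' is the one being perturbed
X <- data.frame(a = 1:5, b = letters[1:5])
n <- nrow(X)

# shuffle: permute the column once, giving n perturbed instances
set.seed(42)
shuffle_a <- X
shuffle_a$a <- sample(X$a)

# cartesian: pair every instance with the value of 'a' taken from
# every OTHER instance, giving n x (n-1) perturbed instances
pairs <- expand.grid(row = seq_len(n), donor = seq_len(n))
pairs <- pairs[pairs$row != pairs$donor, ]
cartesian_a <- X[pairs$row, , drop = FALSE]
cartesian_a$a <- X$a[pairs$donor]
rownames(cartesian_a) <- NULL

nrow(shuffle_a)
nrow(cartesian_a)
```

The cartesian variant removes the randomness of a single shuffle, at the cost of predicting on a data set that grows quadratically with n.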

The loss function can be specified either as a string or as a function passed to feature.imp(). A string is a shortcut for a loss function from the Metrics package; see library(help = "Metrics") for a list of functions. Only use functions that return a single performance value, not a vector. A function supplied directly must take the actual values as its first argument and the predictions as its second.
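A custom loss function that follows this contract might look like the sketch below. The name median_abs_error is hypothetical and chosen for illustration; the commented call at the end assumes a fitted model mod and data X, y as in the Examples section.

```r
# Hypothetical custom loss: median absolute error.
# First argument: actual values; second argument: predictions.
# Returns a single number, as required for feature.imp().
median_abs_error <- function(actual, predicted) {
  stopifnot(length(actual) == length(predicted))
  median(abs(actual - predicted))
}

# Quick check on toy vectors: absolute errors are 1, 0 and 2
median_abs_error(c(1, 2, 3), c(2, 2, 5))

# Passing it to feature.imp() (requires iml and a fitted model):
# imp = feature.imp(mod, X, y, loss = median_abs_error)
```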

References

Fisher, A., Rudin, C., and Dominici, F. (2018). Model Class Reliance: Variable Importance Measures for any Machine Learning Model Class, from the "Rashomon" Perspective. Retrieved from http://arxiv.org/abs/1801.01489

Examples

# NOT RUN {
# We train a tree on the Boston dataset:
if(require("rpart")){
data("Boston", package  = "MASS")
mod = rpart(medv ~ ., data = Boston)

# Separate the features (X) from the target (y)
X = Boston[-which(names(Boston) == 'medv')]
y = Boston$medv

# Compute feature importances as the factor by which the mean absolute error increases
imp = feature.imp(mod, X, y, loss = 'mae')

# Plot the results directly
plot(imp)


# Since the result is a ggplot object, you can extend it: 
library("ggplot2")
plot(imp) + theme_bw()

# If you want to do your own thing, just extract the data: 
imp.dat = imp$data()
head(imp.dat)
ggplot(imp.dat, aes(x = ..feature, y = importance)) + geom_point() + 
theme_bw()

# feature.imp() also works with multiclass classification.
# In this case, the importance measure takes all classes into account
mod = rpart(Species ~ ., data= iris)
X = iris[-which(names(iris) == 'Species')]
y = iris$Species
# For some models we have to specify additional arguments for the predict function
imp = feature.imp(mod, X, y, loss = 'ce', predict.args = list(type = 'prob'))
plot(imp)
# Here we encounter the special case in which the model predicts perfectly:
# the original error is zero, so the importance (a ratio of errors) becomes infinite
imp$data()

# For multiclass classification models, you can choose to only compute performance for one class. 
# Make sure to adapt y
imp = feature.imp(mod, X, y == 'virginica', class = 3, loss = 'ce', 
    predict.args = list(type = 'prob'))
plot(imp)
}
# }
