
iml (version 0.9.0)

FeatureImp: Feature importance

Description

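FeatureImp computes feature importances for prediction models. By default (compare = "ratio"), the importance is measured as the factor by which the model's prediction error increases when the feature is shuffled; with compare = "difference", it is measured as the increase in error.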

Format

R6Class object.

Usage

imp = FeatureImp$new(predictor, loss, compare = "ratio", n.repetitions = 5)

imp$plot()
imp$results
print(imp)

Arguments

For FeatureImp$new():

predictor:

(Predictor) The object (created with Predictor$new()) holding the machine learning model and the data.

loss:

(`character(1)` | function) The loss function. Either the name of a loss (e.g. "ce" for classification or "mse" for regression) or a function. See Details for allowed losses.

compare:

(`character(1)`) Either "ratio" or "difference". Should importance be measured as the ratio or as the difference between the model error after permutation and the original model error? Ratio: error.permutation / error.orig; difference: error.permutation - error.orig.

n.repetitions:

(`numeric(1)`) How often the shuffling of each feature is repeated. The more repetitions, the more stable and accurate the results.

parallel:

(`logical(1)`) Should the method be executed in parallel? If TRUE, a cluster must be registered; see ?foreach::foreach.
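To make the two `compare` settings concrete, here is a minimal base-R sketch (not iml's internal code) that turns an original error and the errors after repeated shuffling of one feature into importance values. The helper name `importance_scores` and the error values are hypothetical; summarizing with the median follows the "median importance over the repetitions" described for the results field.

```r
# Hypothetical helper, not part of iml: applies the two `compare`
# definitions to the error before and after shuffling a feature.
importance_scores <- function(error.orig, error.permutation,
                              compare = c("ratio", "difference")) {
  compare <- match.arg(compare)
  if (compare == "ratio") {
    error.permutation / error.orig   # > 1 means shuffling increased the error
  } else {
    error.permutation - error.orig   # > 0 means shuffling increased the error
  }
}

# Made-up numbers: original error 4, and the error after each of
# n.repetitions = 5 shuffles of one feature.
error.orig <- 4
error.perm <- c(6, 8, 10, 8, 8)

ratio      <- importance_scores(error.orig, error.perm, "ratio")
difference <- importance_scores(error.orig, error.perm, "difference")
median(ratio)       # 2
median(difference)  # 4
```

A ratio of 1 (or a difference of 0) means shuffling left the error unchanged. Each repetition yields one value, so more repetitions give the median more values to stabilize on.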

Fields

original.error:

(`numeric(1)`) The loss of the model before perturbing features.

predictor:

(Predictor) The prediction model that was analysed.

compare:

(`character(1)`) Either "ratio" or "difference", depending on whether the importance was calculated as difference between original model error and model error after permutation or as ratio.

results:

(data.frame) data.frame with the results of the feature importance computation. One row per feature, with columns including feature, importance.05 (5% quantile of the importance values over the repetitions), importance (median importance) and importance.95 (95% quantile). The spread between the 5% and 95% quantiles is visualized as a bar in the plots, the median importance over the repetitions as a point.
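A minimal sketch of how one row of such a summary could be computed from the importance values of all repetitions. It assumes that importance.05 and importance.95 are the 5% and 95% quantiles and uses R's default `quantile()` type; iml's exact quantile method is not stated here. The feature name and values are made up.

```r
# Made-up importance ratios for one feature from 10 repetitions
imp.values <- c(1.8, 2.1, 1.9, 2.4, 2.0, 2.2, 1.7, 2.3, 2.0, 2.1)

# One row of a results-like data.frame (sketch, not iml's internal code)
res.row <- data.frame(
  feature       = "feature_a",                       # hypothetical name
  importance.05 = unname(quantile(imp.values, 0.05)), # lower end of the bar
  importance    = median(imp.values),                 # point in the plot
  importance.95 = unname(quantile(imp.values, 0.95))  # upper end of the bar
)
print(res.row)
```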

Methods

loss(actual,predicted)

The loss function. Can also be applied to data: object$loss(actual, predicted)

plot()

Method to plot the feature importances. See plot.FeatureImp.

clone()

[internal] method to clone the R6 object.

initialize()

[internal] method to initialize the R6 object.

Details

To compute the feature importance for a single feature, the model prediction loss (error) is measured before and after shuffling the values of the feature. Shuffling the feature values destroys the association between the outcome and the feature. The larger the increase in prediction error, the more important the feature. The shuffling is repeated to get more accurate results, since permutation feature importance tends to be quite unstable. Read the Interpretable Machine Learning book to learn about feature importance in detail: https://christophm.github.io/interpretable-ml-book/feature-importance.html
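The procedure above can be sketched in base R without iml. This is a minimal illustration under stated assumptions, not iml's implementation: it fits a linear model on the built-in mtcars data, uses mean absolute error as the loss, the ratio comparison and the median over repetitions. The function name `permutation_importance` is hypothetical.

```r
set.seed(1)
dat <- mtcars[, c("mpg", "wt", "hp", "qsec")]
fit <- lm(mpg ~ wt + hp + qsec, data = dat)

mae <- function(actual, predicted) mean(abs(actual - predicted))

# Hypothetical helper: permutation importance as error ratio, median over repetitions
permutation_importance <- function(model, data, target, loss,
                                   n.repetitions = 5) {
  features   <- setdiff(names(data), target)
  error.orig <- loss(data[[target]], predict(model, newdata = data))
  res <- lapply(features, function(f) {
    ratios <- replicate(n.repetitions, {
      shuffled <- data
      shuffled[[f]] <- sample(shuffled[[f]])  # break the feature/outcome link
      loss(data[[target]], predict(model, newdata = shuffled)) / error.orig
    })
    data.frame(feature = f, importance = median(ratios))
  })
  out <- do.call(rbind, res)
  out[order(out$importance, decreasing = TRUE), ]  # most important first
}

imp <- permutation_importance(fit, dat, "mpg", mae, n.repetitions = 20)
print(imp)
```

Only the shuffled column changes between the two loss evaluations, so any increase in error is attributed to losing that feature's information.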

The loss function can be specified either via a string or by passing a function to FeatureImp$new(). A custom loss function must have the signature function(actual, predicted). Using a string is a shortcut to the loss functions from the Metrics package. Only use functions that return a single performance value, not a vector. Allowed losses are: "ce", "f1", "logLoss", "mae", "mse", "rmse", "mape", "mdae", "msle", "percent_bias", "rae", "rmsle", "rse", "rrse" and "smape". See library(help = "Metrics") for a list of the functions.
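A minimal sketch of custom losses with the function(actual, predicted) signature, written in base R. The names `mae_loss`, `ce_loss` and `abs_errors` are hypothetical; `abs_errors` shows the kind of function to avoid, because it returns a vector.

```r
# Mean absolute error, reduced to a single number
mae_loss <- function(actual, predicted) {
  mean(abs(actual - predicted))
}

# Classification error for class labels: share of wrong predictions
ce_loss <- function(actual, predicted) {
  mean(actual != predicted)
}

# Not suitable: returns one error per observation instead of one value
abs_errors <- function(actual, predicted) {
  abs(actual - predicted)
}

actual    <- c(3, 5, 2, 7)
predicted <- c(2.5, 5, 3, 6)
mae_loss(actual, predicted)                  # 0.625
length(abs_errors(actual, predicted))        # 4, a vector
ce_loss(c("a", "b", "a"), c("a", "a", "a"))  # 1/3
```

With iml loaded and a Predictor `mod` as in the Examples, such a function would be passed directly through the loss argument, e.g. FeatureImp$new(mod, loss = mae_loss).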

References

Fisher, A., Rudin, C., and Dominici, F. (2018). Model Class Reliance: Variable Importance Measures for any Machine Learning Model Class, from the "Rashomon" Perspective. Retrieved from http://arxiv.org/abs/1801.01489

Examples

# NOT RUN {
library("iml")
if (require("rpart")) {
# We train a tree on the Boston dataset:
data("Boston", package = "MASS")
tree = rpart(medv ~ ., data = Boston)
y = Boston$medv
X = Boston[-which(names(Boston) == "medv")]
mod = Predictor$new(tree, data = X, y = y)


# Compute feature importances as the increase in mean absolute error
imp = FeatureImp$new(mod, loss = "mae")

# Plot the results directly
print(plot(imp))


# Since the result is a ggplot object, you can extend it:
if (require("ggplot2")) {
  print(plot(imp) + theme_bw())
  # If you want to do your own thing, just extract the data:
  imp.dat = imp$results
  print(head(imp.dat))
  print(ggplot(imp.dat, aes(x = feature, y = importance)) + geom_point() +
    theme_bw())
}

# We can also look at the difference in model error instead of the ratio
imp = FeatureImp$new(mod, loss = "mae", compare = "difference")

# Plot the results directly
print(plot(imp))


# FeatureImp also works with multiclass classification.
# In this case, the importance measurement takes all classes into account.
tree = rpart(Species ~ ., data = iris)
X = iris[-which(names(iris) == "Species")]
y = iris$Species
# For some models we have to specify additional arguments for the predict function
mod = Predictor$new(tree, data = X, y = y, type = "prob")

imp = FeatureImp$new(mod, loss = "ce")
print(plot(imp))

# For multiclass classification models, you can choose to only compute performance for one class. 
# Make sure to adapt y
mod = Predictor$new(tree, data = X, y = y == "virginica", 
 type = "prob", class = "virginica") 
imp = FeatureImp$new(mod, loss = "ce")
plot(imp)
}
# }
