
survex: Explainable Machine Learning in Survival Analysis

Overview

Survival analysis deals with time-to-event prediction. Aside from well-understood models such as the Cox proportional hazards (CPH) model, many more complex models have recently emerged, but most of them lack interpretability. Because their predictions are functions of time, given either as a survival function or as a cumulative hazard function, standard model-agnostic explanations cannot be applied directly.
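For reference, for a continuous event time $T$ with hazard function $h$, the two functional outputs mentioned above are linked by standard identities. The survival_to_cumulative_hazard() and cumulative_hazard_to_survival() helpers listed in the function index convert between these two representations:

```latex
S(t) = P(T > t), \qquad
H(t) = \int_0^t h(u)\,\mathrm{d}u, \qquad
S(t) = \exp\bigl(-H(t)\bigr) \iff H(t) = -\log S(t).
```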

The survex package provides model-agnostic explanations for machine learning survival models. It is based on the DALEX package. If you're unfamiliar with explainable machine learning, consider referring to the Explanatory Model Analysis (EMA) book: most of the methods included in survex extend those described in EMA and implemented in DALEX to models with functional output.

The main explain() function takes a model and data and creates a standardized explainer object, which is then used as an interface for calculating predictions. We automate creating explainers for models from the following packages: mlr3proba, censored, ranger, randomForestSRC, and survival. Raise an Issue on GitHub if you find models from other packages that we could incorporate into the explain() interface.

Note that an explainer can be created for any survival model with the explain_survival() function by passing the model, data, y, and predict_survival_function arguments.
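A minimal sketch of such a custom explainer follows. It assumes that survex calls predict_survival_function as f(model, newdata, times) and expects a matrix with one row per observation and one column per time point; check ?explain_survival before relying on this. The helper predict_surv_cph() is hypothetical, and a coxph model is used only for illustration, since explain() already supports the survival package directly. The explain_survival() call runs only if survex is installed.

```r
library(survival)

# Example model: any model works as long as survival curves can be computed from it
cph <- coxph(Surv(time, status) ~ trt + karno + age, data = veteran)

# Hypothetical prediction function with the assumed (model, newdata, times)
# signature: returns S(t | x) with one row per observation, one column per time
predict_surv_cph <- function(model, newdata, times) {
  fit <- survfit(model, newdata = newdata)
  # prepend a row of 1s so that times before the first observed time give S = 1
  surv <- rbind(1, as.matrix(fit$surv))
  # index of the last observed time <= each requested time (step function)
  idx <- findInterval(times, fit$time) + 1
  t(surv[idx, , drop = FALSE])
}

# Build the explainer only when survex is available
if (requireNamespace("survex", quietly = TRUE)) {
  explainer_cph <- survex::explain_survival(
    cph,
    data = veteran[, c("trt", "karno", "age")],
    y = Surv(veteran$time, veteran$status),
    predict_survival_function = predict_surv_cph,
    label = "CPH with a custom prediction function",
    verbose = FALSE
  )
}
```

Evaluating the survival curve as a right-continuous step function keeps the predictions consistent with how survfit() reports Kaplan-Meier-type estimates.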

Installation

The package is available on CRAN:

install.packages("survex")

The latest development version can be installed from GitHub using devtools::install_github():

devtools::install_github("https://github.com/ModelOriented/survex")

Simple demo

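library("survex")
library("survival")
library("ranger")

# make the random forest and the permutation-based explanations reproducible
set.seed(123)

# create a model: a random survival forest fitted on the veteran data
model <- ranger(Surv(time, status) ~ ., data = veteran)

# create an explainer; columns 3 and 4 (time, status) are removed from data
# because they make up the survival response passed as y
explainer <- explain(model, 
                     data = veteran[, -c(3, 4)],
                     y = Surv(veteran$time, veteran$status))

# evaluate the model
model_performance(explainer)

# visualize permutation-based feature importance
plot(model_parts(explainer))

# explain one prediction with SurvSHAP(t)
plot(predict_parts(explainer, veteran[1, -c(3, 4)]))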
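The idea behind permutation feature importance, as used by model_parts() above, can be sketched with the survival package alone: shuffle one variable, re-predict, and measure how much performance drops. As a simplifying assumption, this sketch scores the scalar linear predictor of a Cox model with Harrell's C-index via survival::concordance(), whereas survex, as described above, extends such methods to the functional output, so its results will differ. The helpers c_index_of() and pfi() are hypothetical.

```r
library(survival)

set.seed(1)
cph <- coxph(Surv(time, status) ~ trt + karno + age + diagtime, data = veteran)
y <- Surv(veteran$time, veteran$status)

# Harrell's C-index of a risk score; reverse = TRUE because a higher
# linear predictor means higher risk, i.e. shorter survival
c_index_of <- function(risk) {
  concordance(y ~ risk, reverse = TRUE)$concordance
}

base_c <- c_index_of(predict(cph, newdata = veteran, type = "lp"))

# For each variable: permute it B times and average the drop in the C-index
pfi <- function(vars, B = 10) {
  sapply(vars, function(v) {
    drops <- replicate(B, {
      permuted <- veteran
      permuted[[v]] <- sample(permuted[[v]])
      base_c - c_index_of(predict(cph, newdata = permuted, type = "lp"))
    })
    mean(drops)
  })
}

importance <- pfi(c("trt", "karno", "age", "diagtime"))
sort(importance, decreasing = TRUE)
```

Averaging over several permutations reduces the noise introduced by a single random shuffle.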

Functionalities and roadmap

Existing functionalities:

  • unified prediction interface using the explainer object - predict()
  • calculation of performance metrics (Brier Score, Time-dependent C/D AUC, metrics from mlr3proba) - model_performance()
  • calculation of feature importance (Permutation Feature Importance - PFI) - model_parts()
  • calculation of partial dependence (Partial Dependence Profiles - PDP, Accumulated Local Effects - ALE) - model_profile()
  • calculation of 2-dimensional partial dependence (2D PDP, 2D ALE) - model_profile_2d()
  • calculation of local feature attributions (SurvSHAP(t), SurvLIME) - predict_parts()
  • calculation of local ceteris paribus explanations (Ceteris Paribus profiles - CP / Individual Conditional Expectations - ICE) - predict_profile()
  • calculation of global feature attributions using SurvSHAP(t) - model_survshap()
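The Brier score listed above measures, at a fixed time t, the squared difference between the predicted survival probability and the observed status. With censored data it is commonly estimated with inverse probability of censoring weights (IPCW): events up to t are weighted by 1/G(T_i-), subjects still at risk after t by 1/G(t), and subjects censored before t get weight zero, where G is the Kaplan-Meier estimate of the censoring distribution. The sketch below implements that standard estimator with the survival package; brier_at() is a hypothetical helper, not survex's brier_score(), whose details may differ.

```r
library(survival)

# Hypothetical helper (not part of survex): IPCW Brier score at time t
# surv_prob: predicted S(t | x_i) per observation
# time, status: observed times and event indicators (1 = event, 0 = censored)
brier_at <- function(surv_prob, time, status, t) {
  # Kaplan-Meier estimate of the censoring survival function G(u) = P(C > u)
  cens <- survfit(Surv(time, 1 - status) ~ 1)
  G <- function(u, left_limit = FALSE) {
    # left_limit = TRUE gives G(u-), i.e. only censorings strictly before u count
    idx <- findInterval(u, cens$time, left.open = left_limit)
    c(1, cens$surv)[idx + 1]
  }
  event_by_t <- time <= t & status == 1  # event observed by t: true status "dead"
  alive_at_t <- time > t                  # still at risk after t: true status "alive"
  w_event <- ifelse(event_by_t, 1 / G(time, left_limit = TRUE), 0)
  w_alive <- ifelse(alive_at_t, 1 / G(t), 0)
  # observations censored before t receive zero weight in both terms
  mean(w_event * surv_prob^2 + w_alive * (1 - surv_prob)^2)
}

# Example: a Cox model on the veteran data, evaluated at t = 100 days
cph <- coxph(Surv(time, status) ~ trt + karno + age, data = veteran)
fit <- survfit(cph, newdata = veteran)
s100 <- fit$surv[findInterval(100, fit$time), ]
bs_cox <- brier_at(s100, veteran$time, veteran$status, t = 100)
bs_cox
```

Without censoring all weights equal 1 and the estimator reduces to the ordinary mean squared error between S(t | x_i) and the indicator of survival past t.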

Currently in development:

  • ...

Future plans:

  • ... (raise an Issue on GitHub if you have any suggestions)

Usage

Citation

If you use survex, please cite our preprint:

M. Spytek, M. Krzyziński, S. H. Langbein, H. Baniecki, M. N. Wright, P. Biecek. survex: an R package for explaining machine learning survival models. arXiv preprint arXiv:2308.16113, 2023.

@article{spytek2023survex,
    title   = {{survex: an R package for explaining machine learning survival models}},
    author  = {Mikołaj Spytek and Mateusz Krzyziński and Sophie Hanna Langbein and
               Hubert Baniecki and Marvin N. Wright and Przemysław Biecek},
    journal = {arXiv preprint arXiv:2308.16113},
    year    = {2023}
}

Applications of survex

Related work


Version: 1.2.0

License: GPL (>= 3)

Maintainer: Mikołaj Spytek

Last Published: October 24th, 2023

Functions in survex (1.2.0)

  • model_diagnostics - Dataset Level Model Diagnostics
  • model_survshap - Global SHAP Values
  • model_profile_2d - Dataset Level 2-Dimensional Variable Profile for Survival Models
  • loss_one_minus_cd_auc - Calculate Cumulative/Dynamic AUC loss
  • loss_one_minus_c_index - Calculate the Concordance index loss
  • model_profile - Dataset Level Variable Profile as Partial Dependence Explanations for Survival Models
  • model_performance - Dataset Level Performance Measures
  • plot.aggregated_surv_shap - Plot Aggregated SurvSHAP(t) Explanations for Survival Models
  • plot.surv_lime - Plot SurvLIME Explanations for Survival Models
  • plot.model_diagnostics_survival - Plot Model Diagnostics for Survival Models
  • plot.predict_profile_survival - Plot Predict Profile for Survival Models
  • plot.model_parts_survival - Plot Model Parts for Survival Models
  • plot.surv_feature_importance - Plot Permutational Feature Importance for Survival Models
  • plot.surv_model_performance - Plot Model Performance Metrics for Survival Models
  • plot.model_profile_2d_survival - Plot 2-Dimensional Model Profile for Survival Models
  • surv_lime - Helper functions for predict_parts.R
  • plot.model_performance_survival - Plot Model Performance for Survival Models
  • surv_ceteris_paribus - Helper functions for predict_profile.R
  • surv_feature_importance - Helper functions for model_parts.R
  • surv_integrated_feature_importance - Helper functions for model_parts.R
  • predict_profile - Instance Level Profile as Ceteris Paribus for Survival Models
  • predict_parts - Instance Level Parts of Survival Model Predictions
  • predict.surv_explainer - Model Predictions for Survival Models
  • risk_from_chf - Generate Risk Prediction based on the Survival Function
  • plot.surv_model_performance_rocs - Plot ROC Curves for Survival Models
  • surv_shap - Helper functions for predict_parts.R
  • survival_to_cumulative_hazard - Transform Survival to Cumulative Hazard
  • plot.predict_parts_survival - Plot Predict Parts for Survival Models
  • plot.model_profile_survival - Plot Model Profile for Survival Models
  • set_theme_survex - Default Theme for survex plots
  • transform_to_stepfunction - Transform Fixed Point Prediction into a Stepfunction
  • surv_model_info - Extract additional information from the model
  • surv_model_performance - Helper functions for model_performance.R
  • plot.surv_shap - Plot SurvSHAP(t) Explanations for Survival Models
  • cd_auc - Calculate Cumulative/Dynamic AUC
  • cumulative_hazard_to_survival - Transform Cumulative Hazard to Survival
  • explain_survival - A model-agnostic explainer for survival models
  • integrated_brier_score - Calculate integrated Brier score
  • brier_score - Calculate Brier score
  • integrated_cd_auc - Calculate integrated C/D AUC
  • model_parts - Dataset Level Variable Importance for Survival Models
  • c_index - Compute Harrell's Concordance index
  • extract_predict_survshap - Extract Local SurvSHAP(t) from Global SurvSHAP(t)
  • loss_adapt_mlr3proba - Adapt mlr3proba measures for use with survex
  • loss_integrate - Calculate integrated metrics based on time-dependent metrics
  • loss_one_minus_integrated_cd_auc - Calculate integrated C/D AUC loss
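Several helpers in this index (transform_to_stepfunction, survival_to_cumulative_hazard, cumulative_hazard_to_survival) turn predictions at fixed time points into functions of time and convert between survival and cumulative hazard. The idea can be sketched in base R with stats::stepfun(); this is an illustration on assumed toy values, not survex's implementation.

```r
# Illustrative survival probabilities at four fixed time points
# (toy values, not model output)
times <- c(10, 50, 100, 200)
surv  <- c(0.95, 0.80, 0.60, 0.35)

# Right-continuous step functions: S(t) = 1 and H(t) = 0 before the first time
S <- stepfun(times, c(1, surv))
H <- stepfun(times, c(0, -log(surv)))  # H(t) = -log S(t)

# Both can now be evaluated at arbitrary times
S(c(5, 10, 75, 250))        # returns 1.00 0.95 0.80 0.35
exp(-H(c(5, 10, 75, 250)))  # recovers the same survival probabilities
```

Because stepfun() closes intervals on the left by default, each prediction applies from its own time point onward, matching how survival curves are usually reported.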