
mlsurvlrnrs (version 0.0.8)

LearnerSurvXgboostAft: R6 Class to construct an xgboost survival learner for accelerated failure time models

Description

The LearnerSurvXgboostAft class is the interface to accelerated failure time (AFT) models fitted with the xgboost R package, for use with the mlexperiments package.


Super classes

mlexperiments::MLLearnerBase -> mllrnrs::LearnerXgboost -> LearnerSurvXgboostAft

Methods

Inherited methods


Method new()

Create a new LearnerSurvXgboostAft object.

Usage

LearnerSurvXgboostAft$new(metric_optimization_higher_better)

Arguments

metric_optimization_higher_better

A logical. Defines the direction of the optimization metric used throughout the hyperparameter optimization: TRUE if higher metric values are better, FALSE if lower values are better (as for "aft-nloglik").
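As an illustration of what this flag encodes, the base R sketch below picks the best of several candidate scores in the requested direction. select_best() is a hypothetical helper, not part of mlexperiments or mlsurvlrnrs, and the scores are made-up numbers.

```r
# Hypothetical helper (not the package's internals): return the index of the
# best-scoring configuration, given the direction of the metric.
select_best <- function(scores, higher_better) {
  stopifnot(is.numeric(scores), is.logical(higher_better),
            length(higher_better) == 1L)
  if (higher_better) which.max(scores) else which.min(scores)
}

# Made-up scores for three hyperparameter settings
nloglik <- c(2.31, 1.87, 2.05)  # "aft-nloglik": lower is better -> FALSE
scores_up <- c(0.61, 0.66, 0.64) # a metric where higher is better -> TRUE

select_best(nloglik, higher_better = FALSE)  # 2
select_best(scores_up, higher_better = TRUE) # 2
```

Passing the wrong direction would make the search favour the worst configuration, which is why the flag must match the metric set in eval_metric.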

Returns

A new LearnerSurvXgboostAft R6 object.

Examples

if (requireNamespace("xgboost", quietly = TRUE)) {
  LearnerSurvXgboostAft$new(metric_optimization_higher_better = FALSE)
}


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerSurvXgboostAft$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Details

Optimization metric: must be specified via the learner parameter eval_metric (for example, "aft-nloglik"). The learner can be used with

  • mlexperiments::MLTuneParameters

  • mlexperiments::MLCrossValidation

  • mlexperiments::MLNestedCV
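The cross-validation example in the Examples section builds a grid of candidate settings with expand.grid() but passes only its first row, plus nrounds, to MLCrossValidation; searching over all candidates is the job of MLTuneParameters. The base R sketch below rebuilds that grid to show how many candidates it holds and what the single learner_args list looks like. stringsAsFactors = FALSE is an addition here (not in the original example) so that objective and eval_metric stay plain character strings.

```r
# Rebuild the example's hyperparameter grid (base R only)
param_list_xgboost <- expand.grid(
  objective = "survival:aft",
  eval_metric = "aft-nloglik",
  subsample = seq(0.6, 1, .2),         # 0.6, 0.8, 1.0
  colsample_bytree = seq(0.6, 1, .2),  # 0.6, 0.8, 1.0
  min_child_weight = seq(1, 5, 4),     # 1, 5
  learning_rate = c(0.1, 0.2),
  max_depth = seq(1, 5, 4),            # 1, 5
  stringsAsFactors = FALSE
)
nrow(param_list_xgboost)  # 72 candidates (3 * 3 * 2 * 2 * 2)

# The cross-validation example uses only the first row, plus nrounds
learner_args <- c(as.list(param_list_xgboost[1, ]), nrounds = 45L)
str(learner_args)
```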

Also see the official xgboost documentation on AFT models: https://xgboost.readthedocs.io/en/stable/tutorials/aft_survival_analysis.html
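For orientation, the model behind the survival:aft objective, written in roughly the tutorial's notation, relates the survival time Y to the tree ensemble output T(x) through a noise term Z whose distribution is chosen with the xgboost parameter aft_loss_distribution ("normal", "logistic" or "extreme") and scaled by sigma (aft_loss_distribution_scale). The per-observation likelihood terms below follow from a change of variables; "aft-nloglik" is the negative log of this likelihood.

```latex
\ln Y = \mathcal{T}(\mathbf{x}) + \sigma Z

% uncensored observation with event time y
\mathcal{L} = \frac{1}{\sigma y}\, f_Z\!\left(\frac{\ln y - \mathcal{T}(\mathbf{x})}{\sigma}\right)

% right-censored observation, still event-free at time y
\mathcal{L} = 1 - F_Z\!\left(\frac{\ln y - \mathcal{T}(\mathbf{x})}{\sigma}\right)
```

Here f_Z and F_Z are the density and distribution function of Z.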

See Also

xgboost::xgb.train(), xgboost::xgb.cv()
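The xgboost AFT tutorial describes survival labels to xgb.train() as an interval: the DMatrix carries label_lower_bound and label_upper_bound, with an infinite upper bound marking right censoring. The base R sketch below shows that encoding for right-censored (time, status) data. surv_to_aft_bounds() is hypothetical; it illustrates the label format only and is not how LearnerSurvXgboostAft converts its Surv input internally, which this page does not document.

```r
# Hypothetical helper: encode right-censored data as AFT interval bounds.
#   status == 1 (event observed): lower = upper = time
#   status == 0 (right-censored): lower = time, upper = Inf
surv_to_aft_bounds <- function(time, status) {
  stopifnot(length(time) == length(status), all(status %in% c(0, 1)))
  list(
    lower = as.numeric(time),
    upper = ifelse(status == 1, as.numeric(time), Inf)
  )
}

bounds <- surv_to_aft_bounds(time = c(5, 8, 12), status = c(1, 0, 1))
bounds$lower  # 5 8 12
bounds$upper  # 5 Inf 12
```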

Examples

# execution time >2.5 sec
if (requireNamespace("survival", quietly = TRUE) &&
    requireNamespace("glmnet", quietly = TRUE) &&
    requireNamespace("xgboost", quietly = TRUE) &&
    requireNamespace("splitTools", quietly = TRUE)) {

  # survival analysis
  Sys.setenv("OMP_THREAD_LIMIT" = 2)

  dataset <- survival::colon |>
    data.table::as.data.table() |>
    na.omit()
  dataset <- dataset[get("etype") == 2, ]

  seed <- 123
  surv_cols <- c("status", "time", "rx")

  feature_cols <- colnames(dataset)[3:(ncol(dataset) - 1)]

  param_list_xgboost <- expand.grid(
    objective = "survival:aft",
    eval_metric = "aft-nloglik",
    subsample = seq(0.6, 1, .2),
    colsample_bytree = seq(0.6, 1, .2),
    min_child_weight = seq(1, 5, 4),
    learning_rate = c(0.1, 0.2),
    max_depth = seq(1, 5, 4)
  )
  ncores <- 2L

  split_vector <- splitTools::multi_strata(
    df = dataset[, .SD, .SDcols = surv_cols],
    strategy = "kmeans",
    k = 4
  )

  train_x <- model.matrix(
    ~ -1 + .,
    dataset[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
  )
  train_y <- survival::Surv(
    event = (dataset[, get("status")] |>
               as.character() |>
               as.integer()),
    time = dataset[, get("time")],
    type = "right"
  )

  fold_list <- splitTools::create_folds(
    y = split_vector,
    k = 3,
    type = "stratified",
    seed = seed
  )

  surv_xgboost_aft_optimizer <- mlexperiments::MLCrossValidation$new(
    learner = LearnerSurvXgboostAft$new(
      metric_optimization_higher_better = FALSE
    ),
    fold_list = fold_list,
    ncores = ncores,
    seed = seed
  )
  # use the first parameter combination of the grid, with 45 boosting rounds
  surv_xgboost_aft_optimizer$learner_args <- c(
    as.list(param_list_xgboost[1, ]),
    nrounds = 45L
  )
  surv_xgboost_aft_optimizer$performance_metric <- c_index

  # set data
  surv_xgboost_aft_optimizer$set_data(
    x = train_x,
    y = train_y
  )

  surv_xgboost_aft_optimizer$execute()
}
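The example scores each fold with c_index, the concordance index shipped with mlsurvlrnrs. To show what a concordance index measures, the base R sketch below computes a naive Harrell's C. harrell_c_time() is hypothetical and not the package's c_index(); it assumes larger predictions mean longer expected survival (as for predictions on the time scale), and it treats pairs with tied times as not comparable.

```r
# Hypothetical helper: naive Harrell's C for time-scale predictions.
# A pair (i, j) is comparable when i has an observed event and j is still
# under observation beyond i's event time; it is concordant when the model
# predicts a shorter time for i. Tied predictions count as half.
harrell_c_time <- function(time, status, pred) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(time)) {
    if (status[i] != 1) next  # censored cases cannot anchor a pair
    for (j in seq_along(time)) {
      if (time[j] > time[i]) {
        comparable <- comparable + 1
        if (pred[i] < pred[j]) {
          concordant <- concordant + 1
        } else if (pred[i] == pred[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

harrell_c_time(time = c(2, 4, 6, 8), status = c(1, 1, 0, 1),
               pred = c(1, 3, 2, 4))  # 0.8
```

A value of 1 means perfect ranking, 0.5 is no better than chance.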


