mlr3tuning (version 0.5.0)

AutoTuner: AutoTuner

Description

The AutoTuner is a mlr3::Learner which wraps another mlr3::Learner and performs the following steps during $train():

  1. The hyperparameters of the wrapped (inner) learner are tuned on the training data via resampling. The tuning is specified by providing a Tuner, a bbotk::Terminator, a search space as a paradox::ParamSet, a mlr3::Resampling and a mlr3::Measure.

  2. The best found hyperparameter configuration is set as hyperparameters for the wrapped (inner) learner.

  3. A final model is fit on the complete training data using the now parametrized wrapped learner.

During $predict() the AutoTuner just calls the predict method of the wrapped (inner) learner.

Note that this approach allows performing nested resampling by passing an AutoTuner object to mlr3::resample() or mlr3::benchmark(). To access the inner resampling results, set store_tuning_instance = TRUE and execute mlr3::resample() or mlr3::benchmark() with store_models = TRUE (see examples).
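The train/predict round trip described above can be sketched as follows. This is a minimal illustration, not part of the class documentation; the task, search space, and tuning settings are arbitrary placeholders:

```r
library(mlr3)
library(mlr3tuning)
library(paradox)

task = tsk("iris")

at = AutoTuner$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search"),
  search_space = ParamSet$new(
    params = list(ParamDbl$new("cp", lower = 0.001, upper = 0.1))
  )
)

at$train(task)    # steps 1-3: tune, set the best configuration, refit
at$predict(task)  # delegates to the predict method of the tuned inner learner
```

After $train(), the AutoTuner behaves like any other mlr3::Learner at prediction time.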

Super class

mlr3::Learner -> AutoTuner

Public fields

instance_args

(list()) All arguments from construction to create the TuningInstanceSingleCrit.

tuner

(Tuner) The tuning algorithm used during $train().

Active bindings

archive

(ArchiveTuning) Archive of the TuningInstanceSingleCrit.

learner

(mlr3::Learner) Trained learner.

tuning_instance

(TuningInstanceSingleCrit) Internally created tuning instance with all intermediate results.

tuning_result

(named list()) Shortcut to the result of the TuningInstanceSingleCrit.

predict_type

(character(1)) Stores the currently active predict type, e.g. "response". Must be an element of $predict_types.
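For illustration, the active predict type can be changed on the AutoTuner before prediction, provided the wrapped learner supports it (a sketch; constructor arguments are placeholders):

```r
library(mlr3)
library(mlr3tuning)
library(paradox)

at = AutoTuner$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 2),
  tuner = tnr("random_search"),
  search_space = ParamSet$new(
    params = list(ParamDbl$new("cp", lower = 0.001, upper = 0.1))
  )
)

at$predict_type            # "response" by default for classification
at$predict_type = "prob"   # must be an element of at$predict_types
```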

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

AutoTuner$new(
  learner,
  resampling,
  measure,
  terminator,
  tuner,
  search_space = NULL,
  store_tuning_instance = TRUE,
  store_benchmark_result = TRUE,
  store_models = FALSE,
  check_values = FALSE
)

Arguments

learner

(mlr3::Learner) Learner to tune, see TuningInstanceSingleCrit.

resampling

(mlr3::Resampling) Resampling strategy during tuning, see TuningInstanceSingleCrit. This mlr3::Resampling is meant to be the inner resampling, operating on the training set of an arbitrary outer resampling. For this reason it is not feasible to pass an instantiated mlr3::Resampling here.

measure

(mlr3::Measure) Performance measure to optimize.

terminator

(bbotk::Terminator) When to stop tuning, see TuningInstanceSingleCrit.

tuner

(Tuner) Tuning algorithm to run.

search_space

(paradox::ParamSet) Hyperparameter search space, see TuningInstanceSingleCrit.

store_tuning_instance

(logical(1)) If TRUE (default), stores the internally created TuningInstanceSingleCrit with all intermediate results in slot $tuning_instance.

store_benchmark_result

(logical(1)) Store benchmark result in archive?

store_models

(logical(1)) Store models in benchmark result?

check_values

(logical(1)) Should the parameters be checked for validity before each evaluation, and the results be checked afterwards?

Method clone()

The objects of this class are cloneable with this method.

Usage

AutoTuner$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(mlr3)
library(paradox)

task = tsk("iris")
search_space = ParamSet$new(
  params = list(ParamDbl$new("cp", lower = 0.001, upper = 0.1))
)

at = AutoTuner$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 5),
  tuner = tnr("grid_search"),
  search_space = search_space,
  store_tuning_instance = TRUE)

at$train(task)
at$model
at$learner

# Nested resampling
at = AutoTuner$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 5),
  tuner = tnr("grid_search"),
  search_space = search_space,
  store_tuning_instance = TRUE)

resampling_outer = rsmp("cv", folds = 2)
rr = resample(task, at, resampling_outer, store_models = TRUE)

# Aggregate performance of outer results
rr$aggregate()

# Retrieve inner tuning results.
as.data.table(rr)$learner[[1]]$tuning_result
