mlr3hyperband (version 0.2.0)

mlr_optimizers_hyperband: Optimizer using the Hyperband algorithm

Description

OptimizerHyperband class that implements hyperband optimization. Hyperband is a budget-oriented procedure that weeds out poorly performing configurations early in a sequential training process, which increases optimization efficiency.

For this, several brackets are constructed, each with an associated set of configurations. Each bracket has several stages. Different brackets are initialized with different numbers of configurations and different budget sizes. To get an idea of what the bracket layout looks like for a given argument set, please have a look at the details.
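The bracket layout can be sketched in base R. The function below follows Algorithm 1 of the original hyperband paper on the rescaled budget axis (smallest budget 1, largest budget R). `hyperband_layout()` is a hypothetical helper written for illustration, not part of mlr3hyperband, and the package's own rounding may differ in detail.

```r
# Hypothetical helper: hyperband bracket layout (Algorithm 1 of the original
# paper) on the rescaled budget axis, where budgets run from 1 to R.
hyperband_layout = function(R, eta = 2) {
  stopifnot(R >= 1, eta > 1)
  # the epsilon guards against floating-point error in log(R, eta)
  s_max = floor(log(R, eta) + 1e-10)
  stages = list()
  # brackets are counted down from s_max to 0
  for (s in s_max:0) {
    n = ceiling((s_max + 1) * eta^s / (s + 1))  # initial number of configurations
    r = R / eta^s                               # initial budget per configuration
    for (i in 0:s) {
      stages[[length(stages) + 1]] = data.frame(
        bracket = s,
        bracket_stage = i,
        budget_scaled = r * eta^i,
        n_configs = floor(n / eta^i)
      )
    }
  }
  do.call(rbind, stages)
}

layout = hyperband_layout(R = 81, eta = 3)
print(layout)

# total number of evaluations over all brackets
sum(layout$n_configs)
```

For R = 81 and eta = 3, this gives five brackets, ranging from 81 configurations at budget 1 to 5 configurations at budget 81, and 206 evaluations in total.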

To identify the budget, the user has to explicitly specify which parameter of the objective function controls the budget, by tagging exactly one parameter in the paradox::ParamSet with "budget".

Naturally, hyperband terminates once all of its brackets are evaluated. A bbotk::Terminator in the bbotk::OptimInstanceSingleCrit or bbotk::OptimInstanceMultiCrit therefore only acts as an upper bound, and should be set to a low value only if one is unsure how long hyperband will take to finish under the given settings.

Dictionary

This Optimizer can be instantiated via the dictionary mlr_optimizers or with the associated sugar function opt():

mlr_optimizers$get("hyperband")
opt("hyperband")

Parameters

eta

numeric(1) Fraction parameter of the successive halving algorithm: with every stage, the budget per configuration is increased by a factor of eta, and only the best 1/eta of the configurations proceed to the next stage. Non-integer values are supported, but eta must be greater than 1.
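A single stage transition with a non-integer eta can be sketched as follows. `successive_halving_step()` is a hypothetical helper written for illustration, assuming lower performance values are better; it keeps the best floor(n / eta) configurations and multiplies their budget by eta.

```r
# Hypothetical helper: one successive halving step.
# perf:   performance of the configurations in the current stage (lower is better)
# budget: budget per configuration in the current stage
successive_halving_step = function(perf, budget, eta) {
  stopifnot(eta > 1)
  n_keep = floor(length(perf) / eta)
  # indices of the n_keep best configurations
  keep = order(perf)[seq_len(n_keep)]
  list(keep = keep, budget = budget * eta)
}

set.seed(1)
perf = runif(10)
step = successive_halving_step(perf, budget = 1, eta = 2.5)
step$keep    # the 4 best of 10 configurations
step$budget  # budget per configuration in the next stage
```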

sampler

paradox::Sampler Object defining how the samples of the parameter space should be drawn during the initialization of each bracket. The default is uniform sampling.

Archive

The bbotk::Archive holds the following additional columns that are specific to the hyperband optimizer:

  • bracket (integer(1)) The bracket index. The bracket indices shown in the console logs do not match the original hyperband algorithm, which counts the brackets down and stops after evaluating bracket 0. The true bracket indices are given in this column.

  • bracket_stage (integer(1)) The stage within each bracket. Hyperband starts counting at 0.

  • budget_scaled (numeric(1)) The intermediate budget in each bracket stage, as calculated by hyperband. Because hyperband was originally formulated for budgets starting at 1, the budgets are rescaled to allow other lower bounds: they are internally divided by the lower bound of the budget parameter, which gives a lower budget of 1. Before the objective function receives a budget for evaluation, the budget is transformed back to the original scale.

  • budget_real (numeric(1)) The budget on the original scale, i.e. the value the objective function actually uses for evaluation after hyperband has calculated the scaled budget.

  • n_configs (integer(1)) The number of evaluated configurations in each stage. These correspond to the n_i in the original paper.
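The rescaling behind budget_scaled and budget_real can be illustrated with the budget parameter of the example (`fidelity` with bounds 0.01 and 1) and eta = 3. This is a sketch of the arithmetic only, assuming the largest bracket of the layout in the original paper; it does not call mlr3hyperband.

```r
# Bounds of the budget parameter, as for `fidelity` in the example
lower = 0.01
upper = 1
eta = 3

# Rescaled axis: divide by the lower bound so that budgets start at 1
R = upper / lower                    # 100
s_max = floor(log(R, eta) + 1e-10)   # largest bracket index: 4

# Stage budgets of the largest bracket on the scaled axis
budget_scaled = R / eta^s_max * eta^(0:s_max)

# Transform back to the original scale before evaluation
budget_real = budget_scaled * lower
round(budget_real, 4)
```

The last stage reaches the upper bound of the budget parameter again, while the first stage starts slightly above the lower bound because R is not a power of eta.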

Custom sampler

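Hyperband supports a custom paradox::Sampler object for the initial configurations in each bracket. A custom sampler may look like this (the full example is given in the examples section):

library(paradox)
# hypothetical parameters; x1 is the budget parameter and is not sampled
params = list(ParamInt$new("x1", lower = 1, upper = 16, tags = "budget"),
  ParamDbl$new("x2", lower = 0, upper = 1),
  ParamFct$new("x3", levels = c("a", "b", "c")))
# - beta distribution with alpha = 2 and beta = 5
# - categorical distribution with custom probabilities
sampler = SamplerJointIndep$new(list(
  Sampler1DRfun$new(params[[2]], function(n) rbeta(n, 2, 5)),
  Sampler1DCateg$new(params[[3]], prob = c(0.2, 0.3, 0.5))
))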

Runtime

The calculation of the brackets currently assumes that the runtime is linear in the chosen budget parameter. Hyperband is designed so that each bracket requires approximately the same runtime, because the sum of the budgets over all configurations is roughly the same in every bracket. This no longer holds once the runtime does not scale linearly with the budget parameter, even though the sums of the budgets remain the same. A possible adaptation is to introduce a trafo, as shown in the examples section.
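The effect of a non-linear runtime can be checked by summing the cost of each bracket. The sketch below assumes R = 81, eta = 3, the bracket formulas of the original paper, and a hypothetical cost model in which runtime grows with the square of the budget parameter. The third row assumes a trafo b = sqrt(u) from the hyperband budget u to the real parameter b, which makes a runtime quadratic in b linear in u again.

```r
R = 81
eta = 3
s_max = floor(log(R, eta) + 1e-10)

# Cost of one bracket: sum over its stages of n_configs * cost(budget)
bracket_cost = function(s, cost) {
  n = ceiling((s_max + 1) * eta^s / (s + 1))
  i = 0:s
  n_i = floor(n / eta^i)     # configurations per stage
  r_i = R / eta^s * eta^i    # budget per configuration per stage
  sum(n_i * cost(r_i))
}

brackets = s_max:0
costs = rbind(
  # runtime linear in the budget: brackets cost about the same
  linear = sapply(brackets, bracket_cost, cost = function(u) u),
  # runtime quadratic in the budget parameter: costs diverge
  quadratic = sapply(brackets, bracket_cost, cost = function(u) u^2),
  # hypothetical trafo b = sqrt(u): quadratic runtime in b is linear in u
  quadratic_trafo = sapply(brackets, bracket_cost, cost = function(u) sqrt(u)^2)
)
colnames(costs) = paste0("bracket_", brackets)
costs
```

With linear runtime, the five brackets cost between 351 and 405 budget units; with quadratic runtime, the cheapest bracket costs 9801 and the most expensive 32805.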

Progress Bars

$optimize() supports progress bars via the package progressr combined with a Terminator. Simply wrap the function in progressr::with_progress() to enable them. We recommend using package progress as the backend; enable it with progressr::handlers("progress").

Parallelization

In order to support general termination criteria and parallelization, points are evaluated in batches: the points of one stage in a bracket are evaluated in one batch. Parallelization is supported via package future (see mlr3::benchmark()'s section on parallelization for more details).

Logging

Hyperband uses a logger (as implemented in lgr) from package bbotk. Use lgr::get_logger("bbotk") to access and control the logger.
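For example, the console output of hyperband can be reduced by raising the threshold of the bbotk logger. This sketch uses only lgr's logger interface; "warn" is one of lgr's standard log levels.

```r
# access the logger that bbotk (and hence hyperband) writes to
logger = lgr::get_logger("bbotk")

# only show warnings and errors, e.g. to hide informational messages
# about brackets and batches
logger$set_threshold("warn")
```

Calling logger$set_threshold("info") shows the informational messages again.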

Super class

bbotk::Optimizer -> OptimizerHyperband

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

OptimizerHyperband$new()

Method clone()

The objects of this class are cloneable with this method.

Usage

OptimizerHyperband$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(bbotk)
library(data.table)
library(paradox)
library(mlr3hyperband)

search_space = domain = ps(
  x1 = p_dbl(-5, 10), 
  x2 = p_dbl(0, 15), 
  fidelity = p_dbl(1e-2, 1, tags = "budget")
)

# modified branin function
objective = ObjectiveRFunDt$new(
  fun = function(xdt) {
    a = 1
    b = 5.1 / (4 * (pi ^ 2))
    c = 5 / pi
    r = 6
    s = 10
    t = 1 / (8 * pi)
    data.table(y = 
      (a * ((xdt[["x2"]] - 
      b * (xdt[["x1"]] ^ 2L) + 
      c * xdt[["x1"]] - r) ^ 2) + 
      ((s * (1 - t)) * cos(xdt[["x1"]])) + 
      s - (5 * xdt[["fidelity"]] * xdt[["x1"]])))
  },
  domain = domain,
  codomain = ps(y = p_dbl(tags = "minimize"))
)

instance = OptimInstanceSingleCrit$new(
  objective = objective,
  search_space = search_space,
  terminator = trm("none")
)

optimizer = opt("hyperband")

# modifies the instance by reference
optimizer$optimize(instance)

# best scoring evaluation
instance$result

# all evaluations
as.data.table(instance$archive)
