
mlr3proba (version 0.4.9)

mlr_pipeops_survavg: PipeOpSurvAvg

Description

Perform (weighted) prediction averaging from survival PredictionSurvs by connecting PipeOpSurvAvg to multiple PipeOpLearner outputs.

The resulting prediction will aggregate any predict types that are contained within all inputs. Any predict types missing from at least one input will be set to NULL. These are aggregated as follows:

Weights can be set through the weights parameter; if none are provided, each input prediction is weighted equally.
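The aggregation rules above (keep only predict types present in every input, then take a weighted average with equal weights as the default) can be illustrated in base R. This is a minimal sketch and not the package's implementation. Plain named lists stand in for PredictionSurv objects, and the helper avg_predictions() is hypothetical. Only the numeric predict types response, crank and lp are handled, so distr is left out. Rescaling the weights to sum to 1 is an assumption of the sketch.

```r
# Hypothetical helper: plain named lists stand in for PredictionSurv objects.
avg_predictions = function(inputs, weights = NULL) {
  n_in = length(inputs)
  # No weights (or a single weight) means equal weights for every input.
  if (is.null(weights) || length(weights) == 1) weights = rep(1, n_in)
  if (length(weights) != n_in) stop("need exactly one weight per input")
  # Assumption of this sketch: weights are rescaled to sum to 1.
  weights = weights / sum(weights)

  out = list()
  for (type in c("response", "crank", "lp")) {
    vals = lapply(inputs, `[[`, type)
    if (any(vapply(vals, is.null, logical(1)))) {
      # Missing from at least one input: the predict type is set to NULL.
      out[type] = list(NULL)
    } else {
      # One column per input; the matrix-vector product is the weighted mean
      # for each observation.
      out[[type]] = drop(do.call(cbind, vals) %*% weights)
    }
  }
  out
}

# Usage: the second input has no lp, so lp is NULL in the result.
p1 = list(crank = c(1, 2, 3), lp = c(0.1, 0.2, 0.3))
p2 = list(crank = c(3, 2, 1))
res = avg_predictions(list(p1, p2), weights = c(0.2, 0.8))
res$crank
```

Here crank is computed as 0.2 * p1$crank + 0.8 * p2$crank, giving c(2.6, 2.0, 1.4), while lp and response are NULL because not every input supplies them.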


Input and Output Channels

Input and output channels are inherited from PipeOpEnsemble, with PredictionSurv as the type of every input and of the output.

State

The $state is left empty (list()).

Parameters

The parameters are those inherited from PipeOpEnsemble, in particular weights, the relative weight given to each input prediction.

Internals

Inherits from PipeOpEnsemble and implements its private$weighted_avg_predictions() method.

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpEnsemble -> PipeOpSurvAvg

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

PipeOpSurvAvg$new(innum = 0, id = "survavg", param_vals = list(), ...)

Arguments

innum

(numeric(1)) Determines the number of input channels. If innum is 0 (default), a vararg input channel is created that can take an arbitrary number of inputs.

id

(character(1)) Identifier of the resulting object.

param_vals

(list()) List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.

...

(any) Additional arguments passed to mlr3pipelines::PipeOpEnsemble.
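A typical use of these arguments is to fix the number of input channels and connect them to the same number of learners. The following is a sketch that assumes mlr3, mlr3proba and mlr3pipelines are installed. The id "avg2" and the weights c(0.3, 0.7) are illustrative choices. gunion(), %>>% and as_learner() come from mlr3pipelines.

```r
if (requireNamespace("mlr3proba", quietly = TRUE) &&
    requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3)
  library(mlr3proba)
  library(mlr3pipelines)

  # Two fixed input channels instead of the vararg default (innum = 0).
  op = PipeOpSurvAvg$new(innum = 2, id = "avg2")

  # weights is a hyperparameter and can also be changed after construction;
  # its length must match the number of inputs.
  op$param_set$values$weights = c(0.3, 0.7)

  # Feed the task to two learners and average their predictions.
  graph = gunion(list(lrn("surv.coxph"), lrn("surv.kaplan"))) %>>% op
  glrn = as_learner(graph)

  task = tsk("rats")
  glrn$train(task)
  pred = glrn$predict(task)
}
```

Wrapping the graph with as_learner() lets the averaged ensemble be trained, resampled or benchmarked like any other survival learner.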

Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpSurvAvg$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

See Also

Other PipeOps: PipeOpPredTransformer, PipeOpTaskTransformer, PipeOpTransformer, mlr_pipeops_trafopred_regrsurv, mlr_pipeops_trafopred_survregr, mlr_pipeops_trafotask_regrsurv, mlr_pipeops_trafotask_survregr

Examples

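if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3)
  library(mlr3proba)  # provides tsk("rats"), the surv learners and po("survavg")
  library(mlr3pipelines)

  task = tsk("rats")
  # Train two survival learners and predict on the same task
  p1 = lrn("surv.coxph")$train(task)$predict(task)
  p2 = lrn("surv.kaplan")$train(task)$predict(task)
  # Average the two predictions, giving the Kaplan-Meier prediction more weight
  poc = po("survavg", param_vals = list(weights = c(0.2, 0.8)))
  poc$predict(list(p1, p2))
}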
