
unifiedml (version 0.3.0)

benchmark: Benchmark Multiple Models with Cross-Validation and Model-Specific Parameters

Description

Performs k-fold cross-validation on each model in a named list, passing model-specific arguments to each model's fit() call. Optionally runs in parallel, prints per-model messages, and shows a progress bar.
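The k-fold procedure can be sketched in base R as follows. This is an illustrative sketch, not unifiedml's implementation. The helpers make_folds() and cv_scores() are hypothetical, and the usage lines plug in an ordinary linear model scored by RMSE on the built-in mtcars data.

```r
# Illustrative sketch only: make_folds() and cv_scores() are hypothetical, not unifiedml functions.

# Assign each of n rows to one of k folds (sizes differ by at most one), in random order
make_folds <- function(n, k = 5L) {
  sample(rep_len(seq_len(k), n))
}

# For each fold: fit on the other k - 1 folds, predict the held-out fold, and score it
cv_scores <- function(fit_fun, predict_fun, score_fun, X, y, k = 5L) {
  folds <- make_folds(nrow(X), k)
  vapply(seq_len(k), function(i) {
    test  <- folds == i
    model <- fit_fun(X[!test, , drop = FALSE], y[!test])
    pred  <- predict_fun(model, X[test, , drop = FALSE])
    score_fun(y[test], pred)
  }, numeric(1))
}

# Usage: 5-fold RMSE of a linear model predicting mpg from the other mtcars columns
rmse       <- function(truth, pred) sqrt(mean((truth - pred)^2))
fit_lm     <- function(X, y) lm(y ~ ., data = data.frame(X, y = y))
predict_lm <- function(model, X) predict(model, newdata = data.frame(X))

set.seed(1)
scores <- cv_scores(fit_lm, predict_lm, rmse, as.matrix(mtcars[, -1]), mtcars$mpg, k = 5L)
mean(scores)
```

Every row is used for testing exactly once, so the k per-fold scores can be averaged into a single estimate per model.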

Usage

benchmark(
  models,
  X,
  y,
  cv = 5L,
  scoring = NULL,
  params = NULL,
  cl = NULL,
  show_progress = FALSE,
  verbose = TRUE
)

Value

A list containing the CV scores for each model.

Arguments

models

A named list of Model$new(...) objects to benchmark.

X

A data frame or matrix of predictors.

y

A vector of outcomes (factor for classification, numeric for regression).

cv

Integer, number of cross-validation folds (default 5).

scoring

Scoring metric: "rmse", "mae", "accuracy", or "f1". If NULL (the default), the metric is chosen automatically based on the task.
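For reference, the four metric names correspond to the standard formulas sketched below in base R. These functions are hypothetical stand-ins, not unifiedml's own code. In particular, the F1 shown here is the binary score for one chosen positive class; how the package handles "f1" with more than two classes is not documented on this page.

```r
# Hypothetical base-R versions of the four documented metric names
rmse <- function(truth, pred) sqrt(mean((truth - pred)^2))
mae  <- function(truth, pred) mean(abs(truth - pred))

# Share of predictions equal to the true label (works for factors or character vectors)
accuracy <- function(truth, pred) mean(as.character(truth) == as.character(pred))

# Binary F1 score for one chosen positive class
f1 <- function(truth, pred, positive) {
  truth <- as.character(truth)
  pred  <- as.character(pred)
  tp <- sum(pred == positive & truth == positive)  # true positives
  fp <- sum(pred == positive & truth != positive)  # false positives
  fn <- sum(pred != positive & truth == positive)  # false negatives
  if (tp == 0) return(0)
  precision <- tp / (tp + fp)
  recall    <- tp / (tp + fn)
  2 * precision * recall / (precision + recall)
}

rmse(c(1, 2, 3), c(1, 2, 5))                   # sqrt(4 / 3), about 1.155
mae(c(1, 2, 3), c(1, 2, 5))                    # 2 / 3
accuracy(c("a", "b", "a"), c("a", "a", "a"))   # 2 / 3
f1(c("yes", "no", "yes", "yes"), c("yes", "yes", "no", "yes"), positive = "yes")  # 2 / 3
```

RMSE and MAE are errors (lower is better), while accuracy and F1 are higher-is-better, so comparing models requires taking the metric's direction into account.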

params

Optional named list of lists; each sublist holds extra arguments passed to the corresponding model's fit() call. Names must match the names of models.
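The sketch below shows one way a named params list could be matched to models and forwarded to each fit call with do.call(). The helper fit_all() is hypothetical and only mirrors the documented behaviour. stats::lowess stands in for a model's fitting function, and its smoothing argument f plays the role of a model-specific parameter such as ntree.

```r
# Hypothetical helper, not part of unifiedml: forward params[[name]] to the matching model
fit_all <- function(models, x, y, params = NULL) {
  # Every params entry must name an existing model
  stopifnot(all(names(params) %in% names(models)))
  fits <- lapply(names(models), function(name) {
    extra <- params[[name]]              # NULL when this model has no entry
    if (is.null(extra)) extra <- list()
    # Equivalent to models[[name]](x, y, <extra arguments>)
    do.call(models[[name]], c(list(x, y), extra))
  })
  names(fits) <- names(models)
  fits
}

models <- list(default = stats::lowess, narrow = stats::lowess)
params <- list(narrow = list(f = 0.2))   # only "narrow" receives an extra argument
fits <- fit_all(models, cars$speed, cars$dist, params)
```

Models without an entry in params are simply called with the data alone, which is why params can be NULL or cover only some of the models.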

cl

Optional number of clusters for parallel processing (default NULL).
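Folds are independent of one another, so they can be evaluated in parallel. The sketch below uses base R's parallel package (makeCluster() and parLapply()). It assumes, hypothetically, that cl is read as a number of worker processes. fold_rmse() and parallel_cv() are illustrative helpers, not unifiedml functions, and a linear model on mtcars stands in for the benchmarked models.

```r
library(parallel)

# Score one fold with a linear model; defined at top level so workers get only the arguments passed
fold_rmse <- function(i, X, y, folds) {
  test <- folds == i
  fit  <- lm(y ~ ., data = data.frame(X[!test, , drop = FALSE], y = y[!test]))
  pred <- predict(fit, newdata = data.frame(X[test, , drop = FALSE]))
  sqrt(mean((y[test] - pred)^2))
}

# Hypothetical helper: assign folds in the main process, then spread them over n_workers processes
parallel_cv <- function(X, y, k = 5L, n_workers = 2L) {
  folds <- sample(rep_len(seq_len(k), nrow(X)))
  cl <- makeCluster(n_workers)
  on.exit(stopCluster(cl), add = TRUE)   # always shut the workers down
  unlist(parLapply(cl, seq_len(k), fold_rmse, X = X, y = y, folds = folds))
}

set.seed(42)
scores <- parallel_cv(as.matrix(mtcars[, -1]), mtcars$mpg, k = 5L, n_workers = 2L)
```

Because the folds are drawn in the main process before any worker starts, set.seed() there fixes the split regardless of how many workers are used.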

show_progress

Logical, whether to show a progress bar (default FALSE).

verbose

Logical, whether to print messages about each model (default TRUE).

Examples

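if (FALSE) {
# Requires the caret, glmnet, randomForest and xgboost packages
library(unifiedml)
library(randomForest)

# Predictors and a factor outcome (3-class classification)
X <- iris[, 1:4]
y <- iris$Species

# Each entry wraps a fitting function; the names are used to match params
models <- list(
  glm  = Model$new(caret::train),
  rf   = Model$new(randomForest::randomForest),
  xgb  = Model$new(caret::train)
)

# Extra arguments forwarded to each model's fit() call;
# trainControl(method = "none") fits the single tuneGrid row without internal resampling
params <- list(
  glm = list(method = "glmnet",
             tuneGrid = data.frame(alpha = 0, lambda = 0.01),
             trControl = caret::trainControl(method = "none")),
  rf  = list(ntree = 150),
  xgb = list(method = "xgbTree",
             tuneGrid = data.frame(nrounds = 150, max_depth = 3, eta = 0.3,
                                   gamma = 0, colsample_bytree = 1,
                                   min_child_weight = 1, subsample = 1),
             trControl = caret::trainControl(method = "none"))
)

# 5-fold CV of all three models, with a progress bar and per-model messages
results <- benchmark(models, X, y, cv = 5, params = params,
                     show_progress = TRUE, verbose = TRUE)
print(results)
}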
