survdnn (version 0.7.5)

tune_survdnn: Tune Hyperparameters for a survdnn Model via Cross-Validation

Description

Performs k-fold cross-validation over a user-defined hyperparameter grid and selects the best configuration according to the first metric listed in `metrics`.
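The overall procedure can be pictured with a base-R sketch. This is not survdnn's internal code: `toy_score()` and `cv_over_grid()` are hypothetical helpers, and `toy_score()` returns a placeholder number where a real implementation would fit a model on the training folds and evaluate it on the held-out fold. The sketch only shows the shape of the loop: every configuration is evaluated on every fold, giving one row per (configuration, fold) pair.

```r
# Minimal sketch of k-fold cross-validation over a hyperparameter grid.
# toy_score() is a hypothetical stand-in for "fit on the training folds,
# score on the held-out fold"; it is not a survdnn function.
toy_score <- function(config, train, test) {
  config$lr * nrow(test)  # placeholder value, not a real survival metric
}

cv_over_grid <- function(data, grid, folds = 3, seed = 42) {
  set.seed(seed)
  # Assign each row to one of `folds` roughly equal-sized folds.
  fold_id <- sample(rep_len(seq_len(folds), nrow(data)))
  results <- list()
  for (i in seq_len(nrow(grid))) {
    config <- grid[i, , drop = FALSE]
    for (k in seq_len(folds)) {
      train <- data[fold_id != k, , drop = FALSE]
      test  <- data[fold_id == k, , drop = FALSE]
      results[[length(results) + 1]] <- data.frame(
        config_id = i, fold = k, score = toy_score(config, train, test)
      )
    }
  }
  do.call(rbind, results)  # one row per (configuration, fold)
}

df <- data.frame(x = rnorm(30))
grid <- data.frame(lr = c(1e-3, 1e-4))
res <- cv_over_grid(df, grid, folds = 3)
```

With 2 configurations and 3 folds, `res` has 6 rows.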

Usage

tune_survdnn(
  formula,
  data,
  times,
  metrics = "cindex",
  param_grid,
  folds = 3,
  .seed = 42,
  .device = c("auto", "cpu", "cuda"),
  na_action = c("omit", "fail"),
  refit = FALSE,
  return = c("all", "summary", "best_model")
)

Value

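A tibble or a model object, depending on `return`: "all" and "summary" give tibbles of cross-validation results, while "best_model" gives the refitted model or the best hyperparameters.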

Arguments

formula

A survival formula, e.g., `Surv(time, status) ~ x1 + x2`.

data

A data frame containing the variables referenced in `formula`.

times

A numeric vector of evaluation time points.

metrics

A character vector of evaluation metrics: "cindex", "brier", or "ibs". Only the first metric is used for model selection.
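Because only the first metric drives selection, its direction matters: a higher C-index is better, while lower Brier scores and IBS values are better. The sketch below illustrates that rule on made-up per-configuration averages. `select_best()` and the column names are hypothetical, and whether survdnn applies exactly this logic internally is an assumption.

```r
# Sketch: pick the best configuration from per-configuration averages.
# Assumption: "cindex" is maximized; "brier" and "ibs" are minimized.
# select_best() and the table layout are hypothetical, not survdnn objects.
select_best <- function(summary_tbl, metrics) {
  primary <- metrics[1]  # only the first metric is used for selection
  vals <- summary_tbl[[primary]]
  idx <- if (primary == "cindex") which.max(vals) else which.min(vals)
  summary_tbl[idx, , drop = FALSE]
}

# Toy values for illustration only.
summary_tbl <- data.frame(
  config_id = 1:3,
  cindex = c(0.68, 0.72, 0.70),
  ibs = c(0.15, 0.17, 0.14)
)
select_best(summary_tbl, c("cindex", "ibs"))$config_id  # 2
select_best(summary_tbl, c("ibs", "cindex"))$config_id  # 3
```

Reordering `metrics` changes which configuration wins even though the underlying results are identical.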

param_grid

A named list defining hyperparameter combinations to evaluate. Required names: `hidden`, `lr`, `activation`, `epochs`, `loss`.
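A grid might be written as below. The specific values are illustrative assumptions: in particular, `"relu"` as an activation name and `"cox"` as a loss name are not confirmed by this page. `hidden` is written as a list because each candidate is itself a vector of layer widths. Enumerating combinations as a full cross-product is also an assumption about how the grid is expanded; indexing with `expand.grid()` keeps list-valued entries intact.

```r
# Hypothetical param_grid; the values "relu" and "cox" are assumptions.
param_grid <- list(
  hidden = list(c(16), c(32, 16)),  # each candidate is a vector of layer widths
  lr = c(1e-3, 1e-4),
  activation = c("relu"),
  epochs = c(100),
  loss = c("cox")
)

# Check that all required names are present.
required <- c("hidden", "lr", "activation", "epochs", "loss")
stopifnot(all(required %in% names(param_grid)))

# Assuming a full cross-product, enumerate configurations by index.
idx <- expand.grid(lapply(param_grid, seq_along))
nrow(idx)  # 4 configurations (2 hidden x 2 lr x 1 x 1 x 1)

# Materialize the first configuration as a named list.
config_1 <- Map(function(values, i) values[[i]], param_grid, idx[1, ])
```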

folds

Number of cross-validation folds (default: 3).

.seed

Optional seed for reproducibility.
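The role of `folds` and `.seed` can be illustrated with a seeded, balanced fold assignment. `assign_folds()` is a hypothetical helper, and it is an assumption that survdnn splits rows this way. The point is that fixing the seed makes the split identical across runs, so differences between configurations are not confounded by different splits.

```r
# Sketch of seeded fold assignment; assign_folds() is hypothetical.
# Note: set.seed() here resets the global random number generator.
assign_folds <- function(n, folds = 3, seed = 42) {
  set.seed(seed)
  # Balanced labels 1..folds, shuffled across the n rows.
  sample(rep_len(seq_len(folds), n))
}

a <- assign_folds(10, folds = 3, seed = 42)
b <- assign_folds(10, folds = 3, seed = 42)
identical(a, b)  # TRUE: same seed, same split
table(a)         # fold sizes 4, 3, 3
```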

.device

Character string indicating the computation device used when fitting models during cross-validation and refitting. One of `"auto"`, `"cpu"`, or `"cuda"`. `"auto"` uses CUDA if available, otherwise falls back to CPU.
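The `"auto"` rule can be written as a small decision function. `resolve_device()` is hypothetical; its `cuda_available` argument stands in for a runtime check such as `torch::cuda_is_available()` from the torch package, passed in explicitly here so the sketch runs without a GPU or torch installed.

```r
# Sketch of the "auto" device rule; resolve_device() is hypothetical.
# cuda_available stands in for a check such as torch::cuda_is_available().
resolve_device <- function(device = c("auto", "cpu", "cuda"), cuda_available) {
  device <- match.arg(device)
  if (device == "auto") {
    if (cuda_available) "cuda" else "cpu"
  } else {
    device  # explicit choices are passed through unchanged
  }
}

resolve_device("auto", cuda_available = FALSE)  # "cpu"
resolve_device("auto", cuda_available = TRUE)   # "cuda"
```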

na_action

Character. How to handle missing values: `"omit"` drops incomplete rows; `"fail"` errors if any NA is present.
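The two behaviours can be sketched with base R's `complete.cases()`. `handle_na()` is hypothetical, and restricting the check to the variables used in the model (passed as `vars`) is an assumption about which columns count.

```r
# Sketch of the two na_action behaviours; handle_na() is hypothetical.
# Only the columns named in `vars` (e.g. those in the formula) are checked.
handle_na <- function(data, vars, na_action = c("omit", "fail")) {
  na_action <- match.arg(na_action)
  complete <- complete.cases(data[, vars, drop = FALSE])
  if (na_action == "fail" && !all(complete)) {
    stop("Missing values found in ", sum(!complete), " row(s).")
  }
  data[complete, , drop = FALSE]
}

df <- data.frame(
  time = c(5, 8, NA, 12),
  status = c(1, 0, 1, 1),
  x1 = c(0.2, NA, 1.1, -0.4)
)
nrow(handle_na(df, c("time", "status", "x1"), "omit"))  # 2
```

With `"fail"`, the same call stops with an error because two rows are incomplete.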

refit

Logical. If TRUE, refits the best model on the full dataset.

return

One of "all", "summary", or "best_model":

"all"

Returns the per-fold cross-validation results for every hyperparameter combination.

"summary"

Returns results averaged across folds for each configuration.

"best_model"

Returns the best model refitted on the full dataset when `refit = TRUE`; otherwise, returns the best hyperparameters.
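The relationship between "all" and "summary" amounts to averaging per-fold scores within each configuration. The sketch below uses base R `aggregate()` on toy values; the column names `config_id`, `fold`, and `cindex` are assumptions, not survdnn's actual output schema.

```r
# Sketch: per-fold results (like return = "all") averaged into
# per-configuration results (like return = "summary").
# Column names and values are illustrative assumptions.
all_res <- data.frame(
  config_id = rep(1:2, each = 3),
  fold = rep(1:3, times = 2),
  cindex = c(0.70, 0.68, 0.72, 0.65, 0.66, 0.64)
)
summary_res <- aggregate(cindex ~ config_id, data = all_res, FUN = mean)
summary_res
```

Here configuration 1 averages 0.70 and configuration 2 averages 0.65, so configuration 1 would be selected under `"cindex"`.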