survdnn (version 0.7.5)

gridsearch_survdnn: Grid Search for survdnn Hyperparameters

Description

Fits a model for each hyperparameter configuration in a user-specified grid and evaluates each fit on a validation set.

Usage

gridsearch_survdnn(
  formula,
  train,
  valid,
  times,
  metrics = c("cindex", "ibs"),
  param_grid,
  .seed = 42,
  .device = c("auto", "cpu", "cuda")
)
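
The vector default `.device = c("auto", "cpu", "cuda")` resembles the common R idiom in which the first element is the default and the supplied value is validated with `match.arg()`. Whether survdnn resolves the argument this way internally is an assumption; `pick_device()` below is a hypothetical helper that only illustrates the idiom.

```r
# Hypothetical helper, not part of survdnn: illustrates the match.arg() idiom.
pick_device <- function(.device = c("auto", "cpu", "cuda")) {
  # With no argument supplied, match.arg() returns the first choice;
  # a value outside the listed choices raises an error.
  match.arg(.device)
}

pick_device()       # "auto"
pick_device("cpu")  # "cpu"
```

Under this idiom, a value outside the three listed choices (for example `"gpu"`) would fail early with an error rather than being passed on to model fitting.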

Value

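A tibble of hyperparameter configurations and their validation metrics. The example on this page treats it as long format, with `metric` and `value` columns alongside the hyperparameter columns.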

Arguments

formula

A survival formula (e.g., `Surv(time, status) ~ .`)

train

Training dataset

valid

Validation dataset

times

Evaluation time points (numeric vector)

metrics

Evaluation metrics (character vector): one or both of "cindex" and "ibs". Defaults to both.

param_grid

A named list of hyperparameters to search over. Currently supported entries are `hidden`, `lr`, `activation`, `epochs`, and `loss`.

.seed

Optional random seed for reproducibility (default 42)

.device

Character string indicating the computation device used when fitting all models in the grid search. One of "auto", "cpu", or "cuda". This is a runtime setting and is not part of the hyperparameter grid.
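
To get a feel for how large a search a given `param_grid` implies, the sketch below counts and enumerates its combinations in base R. It assumes the search covers the full Cartesian product of the listed values, which is the usual meaning of grid search but is not stated explicitly on this page; `n_values`, `n_configs`, `index_grid`, and `configs` are illustrative names, not survdnn objects.

```r
# Same grid as in the Examples section of this page.
param_grid <- list(
  hidden     = list(c(16, 8), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100),
  loss       = c("cox", "coxtime")
)

# Number of candidate values per hyperparameter.
n_values <- lengths(param_grid)

# Size of the full Cartesian product (assumed to be what the search covers).
n_configs <- prod(n_values)

# Enumerate combinations by index so that list-valued entries such as
# `hidden` (one numeric vector per architecture) stay intact.
index_grid <- expand.grid(lapply(param_grid, seq_along))
configs <- lapply(seq_len(nrow(index_grid)), function(i) {
  Map(function(values, j) values[[j]], param_grid, index_grid[i, ])
})

n_configs             # 4
configs[[1]]$hidden   # 16 8
```

Because every configuration means one model fit, the number of fits grows multiplicatively with each additional value, so it can be worth checking `n_configs` before starting a long run.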

Examples

# \donttest{
library(survdnn)
library(survival)
set.seed(123)

# Simulate small dataset
n <- 300
x1 <- rnorm(n); x2 <- rbinom(n, 1, 0.5)
time <- rexp(n, rate = 0.1)
status <- rbinom(n, 1, 0.7)
df <- data.frame(time, status, x1, x2)

# Split into training and validation
idx <- sample(seq_len(n), floor(0.7 * n))
train <- df[idx, ]
valid <- df[-idx, ]

# Define formula and param grid
formula <- Surv(time, status) ~ x1 + x2
param_grid <- list(
  hidden     = list(c(16, 8), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100),
  loss       = c("cox", "coxtime")
)

# Run grid search
results <- gridsearch_survdnn(
  formula = formula,
  train   = train,
  valid   = valid,
  times   = c(10, 20, 30),
  metrics = c("cindex", "ibs"),
  param_grid = param_grid
)

# View summary
dplyr::group_by(results, hidden, lr, activation, epochs, loss, metric) |>
  dplyr::summarise(mean = mean(value, na.rm = TRUE), .groups = "drop")
# }
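
Once per-configuration means are available, a natural next step is to rank configurations by a single metric. The sketch below does this in base R. It assumes the long-format layout used in the summary code above (`metric` and `value` columns plus one column per hyperparameter); `rank_configs()` is a hypothetical helper, not a survdnn function, and the `toy` values are made-up placeholders rather than output from a real run. For the concordance index, higher is better; for the integrated Brier score, lower is better.

```r
# Hypothetical helper, not part of survdnn: ranks configurations by the mean
# value of one metric. Assumes a long-format table with `metric` and `value`
# columns; all other columns are treated as hyperparameters. List columns
# are not handled in this sketch.
rank_configs <- function(results, metric_name, higher_is_better = TRUE) {
  sub <- results[results$metric == metric_name, , drop = FALSE]
  config_cols <- setdiff(names(sub), c("metric", "value"))
  agg <- aggregate(sub["value"], by = sub[config_cols], FUN = mean, na.rm = TRUE)
  agg[order(agg$value, decreasing = higher_is_better), , drop = FALSE]
}

# Placeholder values for illustration only; not results from a real run.
toy <- data.frame(
  loss   = c("cox", "cox", "coxtime", "coxtime"),
  lr     = 1e-3,
  metric = c("cindex", "ibs", "cindex", "ibs"),
  value  = c(0.60, 0.20, 0.65, 0.25)
)

rank_configs(toy, "cindex", higher_is_better = TRUE)   # best C-index first
rank_configs(toy, "ibs",    higher_is_better = FALSE)  # lowest IBS first
```

Ranking each metric separately keeps the direction of "better" explicit, which matters when the metrics disagree about the best configuration.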