Trains a subgroup-aware gated neural network with a fairness-constrained loss, optionally tunes hyperparameters under lightweight budgets, and returns predictions, gate and expert weights, and summary metrics. Designed to be CRAN-safe:
- no background installs, and nothing is saved unless requested;
- CPU-only torch, with capped threads to avoid oversubscription.
train_gnn(
  prepared_data,
  hyper_grid,
  num_repeats = 20,
  epochs = 300,
  output_dir = tempdir(),
  run_tuning = TRUE,
  best_params = NULL,
  save_outputs = FALSE,
  seed = NULL,
  verbose = FALSE,
  tune_repeats = NULL,
  tune_epochs = NULL
)

Returns a list with:
- final_results: tibble with columns subjectid, true, prob, group, and iteration.
- gate_weights: tibble of gate probabilities and gate entropy per subject and iteration.
- expert_weights: list of expert input-layer weight matrices, one per repeat.
- performance_summary: tibble with AUC and Brier score.
- aif360_data: tibble formatted for fairness-metric tooling.
- tuning_results: tibble of tuning results, or a message if tuning was skipped.
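The metric code inside train_gnn() is not shown on this page. As a rough base-R sketch under assumptions, the snippet below builds a small made-up data.frame shaped like final_results, computes a per-iteration AUC (rank-based Mann-Whitney form) and Brier score, averages them across iterations, and computes natural-log Shannon entropy for rows of gate probabilities. The helpers auc_rank(), brier(), and gate_entropy(), the toy data, and the choice of entropy base are all assumptions for illustration, not the package's implementation.

```r
# Hypothetical helpers for illustration; not part of the package API.

# AUC via the rank-based (Mann-Whitney U) formulation; average ranks handle ties.
auc_rank <- function(truth, prob) {
  r <- rank(prob)
  n_pos <- sum(truth == 1)
  n_neg <- sum(truth == 0)
  (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Brier score: mean squared difference between predicted probability and 0/1 outcome.
brier <- function(truth, prob) mean((prob - truth)^2)

# Shannon entropy (natural log, assumed) of each row of a gate-probability matrix;
# zero probabilities contribute 0 instead of NaN.
gate_entropy <- function(p) {
  -rowSums(ifelse(p > 0, p * log(p), 0))
}

# Made-up data shaped like final_results: 4 subjects, 2 iterations.
final_results <- data.frame(
  subjectid = rep(1:4, times = 2),
  true      = rep(c(0, 0, 1, 1), times = 2),
  prob      = c(0.1, 0.4, 0.35, 0.8, 0.2, 0.3, 0.7, 0.9),
  iteration = rep(1:2, each = 4)
)

# One row of metrics per iteration (i.e. per train/test split).
per_iter <- do.call(rbind, lapply(
  split(final_results, final_results$iteration),
  function(d) data.frame(iteration = d$iteration[1],
                         auc = auc_rank(d$true, d$prob),
                         brier = brier(d$true, d$prob))
))

# Average across iterations.
summary_row <- data.frame(mean_auc = mean(per_iter$auc),
                          mean_brier = mean(per_iter$brier))

# Two subjects, two experts: a uniform gate versus a confident gate.
gates <- matrix(c(0.5, 0.5,
                  0.9, 0.1), nrow = 2, byrow = TRUE)
ent <- gate_entropy(gates)
```

Higher gate entropy means the gate spreads weight more evenly across experts; whether train_gnn() uses this exact entropy definition or averaging scheme is not stated here.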
prepared_data: List returned by prepare_data(), containing:
  - X: numeric feature matrix or data.frame.
  - y: numeric 0/1 outcome.
  - group: numeric codes for the sensitive subgroup.
  - feature_names: character vector (optional).
  - subject_ids: vector of subject identifiers (optional).
hyper_grid: data.frame of candidate hyperparameters with columns lr, hidden_dim, dropout_rate, lambda, and temperature.
num_repeats: Integer (>= 1). Number of repeated train/test splits for the final model; also sets the tuning default when tune_repeats is not supplied.
epochs: Integer (>= 1). Training epochs per run for the final model; also sets the tuning default when tune_epochs is not supplied.
output_dir: Directory for CSV/RDS outputs when save_outputs = TRUE. Defaults to tempdir().
run_tuning: Logical. If TRUE, runs a grid search over hyper_grid and selects the combination with the highest mean AUC. Default TRUE.
best_params: data.frame or list with lr, hidden_dim, dropout_rate, lambda, and temperature; used when run_tuning = FALSE.
save_outputs: Logical. If TRUE, writes CSV/RDS outputs to output_dir. Default FALSE.
seed: Optional integer seed that makes the data splits reproducible. If NULL, the current RNG state is respected.
verbose: Logical. If TRUE, prints progress messages. Default FALSE.
tune_repeats: Integer (>= 1). Repeats per hyperparameter combination, used during tuning only. Defaults to min(5, num_repeats).
tune_epochs: Integer (>= 1). Epochs per run, used during tuning only. Defaults to min(epochs, 100).
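As a sketch of how the tuning-related arguments fit together, the base-R snippet below builds a hyper_grid with the five documented columns via expand.grid(), a one-row best_params for use with run_tuning = FALSE, and reproduces the documented defaults for tune_repeats and tune_epochs. The grid values are arbitrary placeholders, and resolve_tuning_budget() is a hypothetical helper that only mirrors the documented defaults; it is not a function from the package.

```r
# Candidate grid (placeholder values): each combination becomes one row.
hyper_grid <- expand.grid(
  lr             = c(1e-3, 1e-2),
  hidden_dim     = c(16L, 32L),
  dropout_rate   = c(0.1, 0.3),
  lambda         = c(0, 0.5),
  temperature    = 1,
  KEEP.OUT.ATTRS = FALSE
)

# A single fixed combination, as passed via best_params when run_tuning = FALSE.
best_params <- data.frame(lr = 1e-3, hidden_dim = 32L, dropout_rate = 0.1,
                          lambda = 0.5, temperature = 1)

# Both inputs must carry the five documented hyperparameter columns.
required_cols <- c("lr", "hidden_dim", "dropout_rate", "lambda", "temperature")
stopifnot(all(required_cols %in% names(hyper_grid)),
          all(required_cols %in% names(best_params)))

# Hypothetical helper mirroring the documented tuning-budget defaults.
resolve_tuning_budget <- function(num_repeats = 20, epochs = 300,
                                  tune_repeats = NULL, tune_epochs = NULL) {
  if (is.null(tune_repeats)) tune_repeats <- min(5, num_repeats)
  if (is.null(tune_epochs))  tune_epochs  <- min(epochs, 100)
  list(tune_repeats = tune_repeats, tune_epochs = tune_epochs)
}

budget <- resolve_tuning_budget()                          # 5 repeats, 100 epochs
small  <- resolve_tuning_budget(num_repeats = 3, epochs = 50)  # 3 repeats, 50 epochs
```

With a prepared_data list from prepare_data(), a fixed-parameter run would then follow the documented signature, for example train_gnn(prepared_data, hyper_grid, run_tuning = FALSE, best_params = best_params, seed = 1).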