fairGATE (version 0.1.1)

train_gnn: Train and Evaluate the Gated Neural Network (robust splits + safe ROC)

Description

Trains a subgroup-aware gated neural network with a fairness-constrained loss, optionally tunes hyperparameters on a reduced budget (see tune_repeats and tune_epochs), and returns predictions, gate and expert weights, and summary metrics. Designed to be CRAN-safe:

  • No background package installs, and no files written unless save_outputs = TRUE

  • CPU-only torch with a capped thread count, to avoid oversubscription
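This page does not describe how the gate turns inputs into gate weights. The sketch below assumes a softmax gate whose logits are divided by the temperature hyperparameter, and computes the per-subject entropy reported in gate_weights as Shannon entropy in nats. gate_softmax() and gate_entropy() are hypothetical helpers, not fairGATE functions, and the package's internal definitions may differ.

```r
# Hypothetical gate (assumption): temperature-scaled softmax over per-subject gate logits
gate_softmax <- function(logits, temperature = 1) {
  z <- logits / temperature
  z <- z - apply(z, 1, max)  # subtract each row's max for numerical stability
  e <- exp(z)
  e / rowSums(e)             # each row now sums to 1
}

# Shannon entropy (nats) of each subject's gate distribution; 0 * log(0) treated as 0
gate_entropy <- function(p) -rowSums(ifelse(p > 0, p * log(p), 0))

# Two subjects, two experts: one confident gate, one undecided
logits <- rbind(c(2, 0), c(0, 0))
p1 <- gate_softmax(logits, temperature = 1)
p5 <- gate_softmax(logits, temperature = 5)
gate_entropy(p1)  # second subject is uniform, so its entropy is log(2)
gate_entropy(p5)  # higher temperature flattens the first subject's gate
```

Under this assumption, a lower temperature sharpens gate assignments (lower entropy) and a higher one spreads weight across experts.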

Usage

train_gnn(
  prepared_data,
  hyper_grid,
  num_repeats = 20,
  epochs = 300,
  output_dir = tempdir(),
  run_tuning = TRUE,
  best_params = NULL,
  save_outputs = FALSE,
  seed = NULL,
  verbose = FALSE,
  tune_repeats = NULL,
  tune_epochs = NULL
)
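A minimal call sketch on simulated data. prepared_data is normally produced by prepare_data(), whose arguments are not documented here, so this example builds the list by hand from the fields listed under Arguments; whether train_gnn() accepts a hand-built list is an assumption. The parameter values are illustrative only, and the guard keeps the code runnable when fairGATE or a working torch backend is absent.

```r
# Simulated toy data (assumption: 4 numeric features, 2 subgroups)
set.seed(42)
n <- 200
X <- matrix(rnorm(n * 4), ncol = 4, dimnames = list(NULL, paste0("x", 1:4)))
group <- sample(c(1, 2), n, replace = TRUE)                    # numeric subgroup codes
y <- as.numeric(rbinom(n, 1, plogis(X[, 1] - 0.5 * X[, 2])))   # numeric 0/1 outcome

# Hand-built stand-in for the list prepare_data() would normally return
prepared_data <- list(X = X, y = y, group = group,
                      feature_names = colnames(X), subject_ids = seq_len(n))

# One fixed combination; values are illustrative, not recommendations
params <- data.frame(lr = 1e-3, hidden_dim = 16, dropout_rate = 0.1,
                     lambda = 0.5, temperature = 1)

if (requireNamespace("fairGATE", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  fit <- fairGATE::train_gnn(
    prepared_data,
    hyper_grid = params,   # required argument even when tuning is skipped
    num_repeats = 2,       # small budgets keep the example fast
    epochs = 10,
    run_tuning = FALSE,
    best_params = params,
    seed = 1
  )
  print(fit$performance_summary)
}
```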

Value

A list with:

  • final_results (tibble: subjectid, true, prob, group, iteration)

  • gate_weights (tibble of gate probabilities and gate entropy per subject and iteration)

  • expert_weights (list of expert input-layer weight matrices, one per repeat)

  • performance_summary (tibble with AUC and Brier score)

  • aif360_data (tibble formatted for fairness-metric tooling such as AIF360)

  • tuning_results (tibble of tuning results, or a message when tuning is skipped)
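One way to derive AUC and Brier scores from a final_results-shaped table, sketched in base R. It assumes AUC is computed per iteration and then averaged, and that a split with only one outcome class yields NA rather than an error (one reading of "safe ROC"). auc_safe() and brier() are hypothetical helpers; the package's exact computation may differ.

```r
# Rank-based (Mann-Whitney) AUC; NA when a split contains only one class
auc_safe <- function(truth, prob) {
  pos <- truth == 1
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  if (n_pos == 0 || n_neg == 0) return(NA_real_)
  r <- rank(prob)  # average ranks, so tied scores count as half-correct
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Brier score: mean squared error of predicted probabilities
brier <- function(truth, prob) mean((prob - truth)^2)

# Toy stand-in for final_results with two iterations
final_results <- data.frame(
  subjectid = 1:8,
  true = c(0, 0, 1, 1, 0, 1, 0, 1),
  prob = c(0.10, 0.40, 0.35, 0.80, 0.20, 0.90, 0.60, 0.70),
  group = c(1, 2, 1, 2, 1, 2, 1, 2),
  iteration = rep(1:2, each = 4)
)

per_iter <- do.call(rbind, lapply(
  split(final_results, final_results$iteration),
  function(d) data.frame(iteration = d$iteration[1],
                         auc = auc_safe(d$true, d$prob),
                         brier = brier(d$true, d$prob))
))

# Average across iterations, skipping splits where AUC was undefined
summary_tbl <- data.frame(mean_auc = mean(per_iter$auc, na.rm = TRUE),
                          mean_brier = mean(per_iter$brier))
summary_tbl  # mean_auc 0.875, mean_brier about 0.1416
```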

Arguments

prepared_data

A list, as returned by prepare_data(), containing:

  • X (matrix or data.frame of numeric features)

  • y (numeric outcome coded 0/1)

  • group (numeric codes for sensitive subgroup)

  • feature_names (character vector; optional)

  • subject_ids (vector; optional)
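A quick pre-flight check for the fields listed above can catch malformed input before any training starts. check_prepared_data() is a hypothetical helper, not part of fairGATE, and its rules (one entry per row of X, optional fields matched to X's shape) are assumptions drawn from this field list.

```r
# Hypothetical checker for the prepared_data fields; not part of fairGATE
check_prepared_data <- function(pd) {
  stopifnot(is.list(pd), all(c("X", "y", "group") %in% names(pd)))
  X <- as.matrix(pd$X)  # a data.frame of numeric columns becomes a numeric matrix
  if (!is.numeric(X)) stop("X must contain only numeric features")
  n <- nrow(X)
  if (length(pd$y) != n || !all(pd$y %in% c(0, 1)))
    stop("y must be a 0/1 vector with one entry per row of X")
  if (length(pd$group) != n || !is.numeric(pd$group))
    stop("group must be numeric codes, one per row of X")
  if (!is.null(pd$feature_names) && length(pd$feature_names) != ncol(X))
    stop("feature_names must have one entry per column of X")
  if (!is.null(pd$subject_ids) && length(pd$subject_ids) != n)
    stop("subject_ids must have one entry per row of X")
  invisible(TRUE)
}

pd <- list(X = matrix(c(1, 2, 3, 4), nrow = 2), y = c(0, 1), group = c(1, 2))
check_prepared_data(pd)  # passes silently
```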

hyper_grid

A data.frame of candidate hyperparameter combinations, one per row, with columns lr, hidden_dim, dropout_rate, lambda and temperature.
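A full factorial grid can be built with base R's expand.grid(). The values below are illustrative, not recommended settings.

```r
# Every combination of the listed values becomes one row (2 * 2 * 1 * 2 * 1 = 8)
hyper_grid <- expand.grid(
  lr = c(1e-3, 1e-2),
  hidden_dim = c(16, 32),
  dropout_rate = 0.1,
  lambda = c(0, 0.5),
  temperature = 1
)
nrow(hyper_grid)  # 8
```

Because each row is trained tune_repeats times during tuning, tuning cost grows with the number of rows, so adding a value to one column multiplies the work.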

num_repeats

Integer (>=1). Number of repeated train/test splits for the final model. If tune_repeats is NULL, tuning uses min(5, num_repeats) repeats per combination.

epochs

Integer (>=1). Training epochs per run for the final model. If tune_epochs is NULL, tuning uses min(epochs, 100) epochs per run.

output_dir

Directory where CSV/RDS outputs are written when save_outputs = TRUE. Defaults to tempdir().

run_tuning

Logical. If TRUE, runs a grid search over the rows of hyper_grid and selects the combination with the highest mean AUC.
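Selecting the best row by mean AUC can be sketched as follows. The layout of tuning_log (one AUC per combination and repeat) is hypothetical, and dropping NA AUCs before averaging is an assumption about how undefined splits are handled.

```r
# Toy grid of three candidate combinations (illustrative values)
hyper_grid <- data.frame(
  lr = c(1e-3, 1e-3, 1e-2),
  hidden_dim = c(16, 32, 16),
  dropout_rate = 0.1,
  lambda = 0.5,
  temperature = 1
)

# Hypothetical tuning log: one AUC per (combination, repeat)
tuning_log <- data.frame(
  combo = rep(1:3, each = 2),
  auc = c(0.70, 0.74, 0.81, 0.79, 0.76, NA)  # NA marks a split with undefined AUC
)

# Mean AUC per combination, ignoring undefined splits, then take the maximum
mean_auc <- tapply(tuning_log$auc, tuning_log$combo, mean, na.rm = TRUE)
best_combo <- as.integer(names(which.max(mean_auc)))
best_params <- hyper_grid[best_combo, ]
best_params  # row 2: lr = 0.001, hidden_dim = 32
```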

best_params

A data.frame or list with lr, hidden_dim, dropout_rate, lambda and temperature; used when run_tuning = FALSE.

save_outputs

Logical. If TRUE, writes CSV/RDS outputs to output_dir. Default FALSE.
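The opt-in write behaviour can be sketched as a helper that does nothing unless save_outputs is TRUE. save_results() and both file names are hypothetical; this page does not state which files train_gnn() actually writes.

```r
# Hypothetical writer mirroring the save_outputs/output_dir contract; file names are made up
save_results <- function(results, output_dir = tempdir(), save_outputs = FALSE) {
  if (!isTRUE(save_outputs)) return(invisible(character(0)))  # default: write nothing
  dir.create(output_dir, showWarnings = FALSE, recursive = TRUE)
  csv_path <- file.path(output_dir, "final_results.csv")
  rds_path <- file.path(output_dir, "train_gnn_results.rds")
  utils::write.csv(results$final_results, csv_path, row.names = FALSE)
  saveRDS(results, rds_path)  # whole result list, including non-tabular parts
  invisible(c(csv_path, rds_path))
}

res <- list(final_results = data.frame(subjectid = 1:2, true = c(0, 1), prob = c(0.2, 0.9)))
out_dir <- file.path(tempdir(), "fairgate_demo")
paths <- save_results(res, out_dir, save_outputs = TRUE)
```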

seed

Optional integer seed that makes the data splits reproducible. If NULL, the current RNG state is used.
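How an optional seed yields reproducible repeated splits, sketched with a hypothetical make_splits() helper. test_frac is an assumed parameter, not a train_gnn() argument, and plain random splitting is an assumption (the package may stratify).

```r
# Hypothetical split generator; test_frac is an assumption, not a train_gnn argument
make_splits <- function(n, num_repeats, test_frac = 0.2, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)  # otherwise continue from the current RNG state
  lapply(seq_len(num_repeats), function(i) {
    test <- sort(sample.int(n, size = floor(test_frac * n)))
    list(train = setdiff(seq_len(n), test), test = test)
  })
}

a <- make_splits(50, num_repeats = 3, seed = 123)
b <- make_splits(50, num_repeats = 3, seed = 123)
identical(a, b)  # TRUE: same seed, same splits
```

Calling set.seed() changes the global RNG stream; a variant that leaves the caller's state untouched could wrap the sampling in withr::with_seed().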

verbose

Logical. Print progress messages. Default FALSE.

tune_repeats

Integer (>=1). Repeats per hyperparameter combination, used during tuning only. If NULL, defaults to min(5, num_repeats).

tune_epochs

Integer (>=1). Epochs per run, used during tuning only. If NULL, defaults to min(epochs, 100).
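The documented defaults for the tuning budget can be reproduced with a small hypothetical helper; resolve_tuning_budget() is for illustration and is not exported by fairGATE.

```r
# Hypothetical helper reproducing the documented tuning-budget defaults
resolve_tuning_budget <- function(num_repeats, epochs,
                                  tune_repeats = NULL, tune_epochs = NULL) {
  list(
    tune_repeats = if (is.null(tune_repeats)) min(5, num_repeats) else tune_repeats,
    tune_epochs  = if (is.null(tune_epochs)) min(epochs, 100) else tune_epochs
  )
}

resolve_tuning_budget(num_repeats = 20, epochs = 300)  # tune_repeats 5, tune_epochs 100
resolve_tuning_budget(num_repeats = 3, epochs = 50)    # tune_repeats 3, tune_epochs 50
```

With the default num_repeats = 20 and epochs = 300, each grid row is therefore tuned with 5 repeats of 100 epochs, while the final model uses 20 repeats of 300 epochs.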