
neuralGAM (version 2.0.1)

plot_history: Plot training loss history for a neuralGAM model

Description

This function visualizes the training and/or validation loss recorded at the end of each backfitting iteration for each term-specific model in a fitted neuralGAM object. It reads the history component that is stored in a trained neuralGAM model.
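The page does not show how the history is stored. The sketch below assumes, hypothetically, that it holds one data frame per term with one row per backfitting iteration and columns `loss` and `val_loss`. Under that assumption, the per-term tables can be stacked into a single long table keyed by term, iteration and metric, which is the shape a faceted loss plot needs. The layout and the helper name `history_to_long()` are illustrative only and are not neuralGAM's internal structure.

```r
# Hypothetical per-term history (NOT the real neuralGAM layout):
# one data frame per term, one row per backfitting iteration.
history <- list(
  x1 = data.frame(iteration = 1:3, loss = c(0.90, 0.45, 0.30),
                  val_loss = c(0.95, 0.50, 0.36)),
  x2 = data.frame(iteration = 1:3, loss = c(0.60, 0.40, 0.35),
                  val_loss = c(0.66, 0.44, 0.41))
)

# Hypothetical helper: stack the per-term tables into one long data frame
# with columns term, iteration, metric, value (one row per combination).
history_to_long <- function(history) {
  rows <- lapply(names(history), function(term) {
    h <- history[[term]]
    metrics <- setdiff(names(h), "iteration")
    do.call(rbind, lapply(metrics, function(m) {
      data.frame(term = term, iteration = h$iteration,
                 metric = m, value = h[[m]],
                 stringsAsFactors = FALSE)
    }))
  })
  do.call(rbind, rows)
}

long <- history_to_long(history)
```

With 2 terms, 2 metrics and 3 iterations, `long` has 12 rows.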

Usage

plot_history(model, select = NULL, metric = c("loss", "val_loss"))

Value

A ggplot object showing the loss curves by backfitting iteration, with facets per term.
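The return value is described as a ggplot with one facet per term. The following sketch shows one way to build such a plot with ggplot2 from a long-format loss table. The toy values and column names are hypothetical, and the actual layers, labels and theme used by `plot_history()` may differ.

```r
library(ggplot2)

# Toy long-format loss table (hypothetical values, not real neuralGAM output).
long <- data.frame(
  term      = rep(c("x1", "x2"), each = 4),
  iteration = rep(c(1, 2), times = 4),
  metric    = rep(rep(c("loss", "val_loss"), each = 2), times = 2),
  value     = c(0.90, 0.45, 0.95, 0.50, 0.60, 0.40, 0.66, 0.44),
  stringsAsFactors = FALSE
)

# One line per metric, one panel per term; free y scales because
# term-specific losses can sit on very different ranges.
p <- ggplot(long, aes(x = iteration, y = value, colour = metric)) +
  geom_line() +
  geom_point() +
  facet_wrap(~ term, scales = "free_y") +
  labs(x = "Backfitting iteration", y = "Loss", colour = "Metric")
```

Because `plot_history()` returns an ordinary ggplot object, you can add further layers or themes to it in the usual way, for example `plot_history(model) + theme_minimal()`.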

Arguments

model

A fitted neuralGAM model.

select

Optional character vector of term names (e.g. "x1" or c("x1", "x3")) to subset the history. If NULL (default), all terms are included.

metric

Character vector indicating which loss metric(s) to plot. Options are "loss", "val_loss", or both. Defaults to both.
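The following base R filter makes the `select` and `metric` semantics concrete. It works on a long-format loss table. With `select = NULL`, every term is kept. The `metric` argument is matched against the two documented options with `match.arg(several.ok = TRUE)`, so its default keeps both. The helper name `filter_history()` and the decision to stop on unknown term names are assumptions for illustration. They are not documented package behavior.

```r
# Hypothetical helper mirroring the documented meaning of `select` and
# `metric`; this is not the package's internal code.
filter_history <- function(long, select = NULL,
                           metric = c("loss", "val_loss")) {
  # Restrict `metric` to the documented choices; the default keeps both.
  metric <- match.arg(metric, several.ok = TRUE)
  if (!is.null(select)) {
    # Assumption: unknown term names are treated as an error.
    unknown <- setdiff(select, unique(long$term))
    if (length(unknown) > 0) {
      stop("Unknown term(s): ", paste(unknown, collapse = ", "))
    }
    long <- long[long$term %in% select, , drop = FALSE]
  }
  long[long$metric %in% metric, , drop = FALSE]
}

# Toy table: 3 terms x 2 metrics x 2 iterations = 12 rows (hypothetical values).
long <- expand.grid(iteration = 1:2, metric = c("loss", "val_loss"),
                    term = c("x1", "x2", "x3"), stringsAsFactors = FALSE)
long$value <- seq_len(nrow(long)) / 10

only_x1_val <- filter_history(long, select = "x1", metric = "val_loss")
x1_and_x3   <- filter_history(long, select = c("x1", "x3"))
```

`only_x1_val` keeps the 2 validation-loss rows for `x1`. `x1_and_x3` keeps both metrics for the two selected terms, which gives 8 rows.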

Author

Ines Ortega-Fernandez, Marta Sestelo

Examples

if (FALSE) {
  library(neuralGAM)  # attach the package so plot_history() is found
  set.seed(123)
  n <- 200
  x1 <- runif(n, -2, 2)
  x2 <- runif(n, -2, 2)
  y <- 2 + x1^2 + sin(x2) + rnorm(n, 0, 0.1)
  df <- data.frame(x1 = x1, x2 = x2, y = y)

  model <- neuralGAM::neuralGAM(
    y ~ s(x1) + s(x2),
    data = df,
    num_units = 8,
    family = "gaussian",
    max_iter_backfitting = 2,
    max_iter_ls = 1,
    learning_rate = 0.01,
    seed = 42,
    validation_split = 0.2,
    verbose = 0
  )

  plot_history(model)                      # Plot all terms
  plot_history(model, select = "x1")       # Plot just x1
  plot_history(model, metric = "val_loss") # Plot only validation loss
}
