
ggmlR (version 0.6.1)

ggml_fit_opt: Fit model with R-side epoch loop and callbacks

Description

Trains a model epoch by epoch in R, allowing callbacks for early stopping and learning rate scheduling. Optimizer state (momentum) is preserved across all epochs.

Usage

ggml_fit_opt(
  sched,
  ctx_compute,
  inputs,
  outputs,
  dataset,
  loss_type = ggml_opt_loss_type_mse(),
  optimizer = ggml_opt_optimizer_type_adamw(),
  nepoch = 10L,
  nbatch_logical = 32L,
  val_split = 0,
  callbacks = list(),
  silent = FALSE
)

Value

Data frame with columns `epoch`, `train_loss`, `train_accuracy`, `val_loss`, and `val_accuracy`

Arguments

sched

Backend scheduler

ctx_compute

Compute context (for temporary tensors)

inputs

Input tensor with shape [ne_datapoint, batch_size]

outputs

Output tensor with shape [ne_label, batch_size]

dataset

Dataset created with `ggml_opt_dataset_init()`

loss_type

Loss type (default: MSE)

optimizer

Optimizer type (default: AdamW)

nepoch

Number of epochs

nbatch_logical

Logical batch size (for gradient accumulation)

val_split

Fraction of data for validation (0.0 to 1.0)

callbacks

List of callbacks. Each callback is itself a list that may define `on_epoch_begin(epoch, logs, state)` and/or `on_epoch_end(epoch, logs, state)`. Built-in factories: `ggml_callback_early_stopping()`, `ggml_schedule_step_decay()`, `ggml_schedule_cosine_decay()`, `ggml_schedule_reduce_on_plateau()`. `state` is a mutable environment with fields `stop` (set to `TRUE` to stop training), `lr_ud`, and `nepoch`.

silent

Whether to suppress per-epoch progress output
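
A hand-written callback can sit alongside the built-in factories. The sketch below follows the calling convention described under `callbacks`; the `logs$train_loss` field name is inferred from the columns of the returned data frame, and the loss threshold is purely illustrative:

```r
# Minimal custom callback sketch (assumes the on_epoch_begin/on_epoch_end
# convention documented above; the 0.01 threshold is illustrative only).
threshold_callback <- list(
  on_epoch_begin = function(epoch, logs, state) {
    # `state` is a mutable environment shared across callbacks
    cat(sprintf("Starting epoch %d of %d\n", epoch, state$nepoch))
  },
  on_epoch_end = function(epoch, logs, state) {
    # `logs` is assumed to carry this epoch's metrics (e.g. logs$train_loss)
    if (!is.null(logs$train_loss) && logs$train_loss < 0.01) {
      state$stop <- TRUE  # request early termination
    }
  }
)

# Passed alongside the built-in factories, e.g.:
# callbacks = list(threshold_callback, ggml_callback_early_stopping())
```

Because `state` is an environment, assignments such as `state$stop <- TRUE` are visible to the training loop without any return value from the callback.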

See Also

Other optimization: ggml_opt_alloc(), ggml_opt_context_optimizer_type(), ggml_opt_dataset_data(), ggml_opt_dataset_free(), ggml_opt_dataset_get_batch(), ggml_opt_dataset_init(), ggml_opt_dataset_labels(), ggml_opt_dataset_ndata(), ggml_opt_dataset_shuffle(), ggml_opt_default_params(), ggml_opt_epoch(), ggml_opt_eval(), ggml_opt_fit(), ggml_opt_free(), ggml_opt_get_lr(), ggml_opt_grad_acc(), ggml_opt_init(), ggml_opt_init_for_fit(), ggml_opt_inputs(), ggml_opt_labels(), ggml_opt_loss(), ggml_opt_loss_type_cross_entropy(), ggml_opt_loss_type_mean(), ggml_opt_loss_type_mse(), ggml_opt_loss_type_sum(), ggml_opt_ncorrect(), ggml_opt_optimizer_name(), ggml_opt_optimizer_type_adamw(), ggml_opt_optimizer_type_sgd(), ggml_opt_outputs(), ggml_opt_pred(), ggml_opt_prepare_alloc(), ggml_opt_reset(), ggml_opt_result_accuracy(), ggml_opt_result_free(), ggml_opt_result_init(), ggml_opt_result_loss(), ggml_opt_result_ndata(), ggml_opt_result_pred(), ggml_opt_result_reset(), ggml_opt_set_lr(), ggml_opt_static_graphs()

Examples

if (FALSE) {
# Train for up to 50 epochs, holding out 20% of the data for validation;
# early stopping monitors validation loss with a patience of 5 epochs.
history <- ggml_fit_opt(sched, ctx_compute, inputs, outputs, dataset,
  nepoch = 50, val_split = 0.2,
  callbacks = list(
    ggml_callback_early_stopping(monitor = "val_loss", patience = 5),
    ggml_schedule_cosine_decay()
  ))
}
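
Since `ggml_fit_opt()` returns a plain data frame, the training history can be inspected with base R. A sketch, assuming training above has run and produced a non-empty `history`:

```r
# Plot training vs. validation loss across epochs (base graphics)
plot(history$epoch, history$train_loss, type = "l",
     xlab = "Epoch", ylab = "Loss")
lines(history$epoch, history$val_loss, lty = 2)
legend("topright", legend = c("train", "validation"), lty = c(1, 2))

# Epoch with the lowest validation loss
best_epoch <- history$epoch[which.min(history$val_loss)]
```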
