Trains a model epoch by epoch in R, with support for callbacks such as early stopping and learning-rate scheduling. Optimizer state (e.g. momentum) is preserved across epochs.
Usage

ggml_fit_opt(
  sched,
  ctx_compute,
  inputs,
  outputs,
  dataset,
  loss_type = ggml_opt_loss_type_mse(),
  optimizer = ggml_opt_optimizer_type_adamw(),
  nepoch = 10L,
  nbatch_logical = 32L,
  val_split = 0,
  callbacks = list(),
  silent = FALSE
)

Value

A data frame with columns epoch, train_loss, train_accuracy, val_loss, val_accuracy.
Arguments

sched: Backend scheduler.

ctx_compute: Compute context (for temporary tensors).

inputs: Input tensor with shape [ne_datapoint, batch_size].

outputs: Output tensor with shape [ne_label, batch_size].

dataset: Dataset created with `ggml_opt_dataset_init()`.

loss_type: Loss type (default: MSE).

optimizer: Optimizer type (default: AdamW).

nepoch: Number of epochs.

nbatch_logical: Logical batch size (used for gradient accumulation).

val_split: Fraction of the data held out for validation (0.0 to 1.0).

callbacks: List of callback lists. Each element may define `on_epoch_begin(epoch, logs, state)` and/or `on_epoch_end(epoch, logs, state)`. Built-in factories: `ggml_callback_early_stopping()`, `ggml_schedule_step_decay()`, `ggml_schedule_cosine_decay()`, `ggml_schedule_reduce_on_plateau()`. `state` is a mutable environment with fields `stop` (set to TRUE to stop training), `lr_ud`, and `nepoch`.

silent: Whether to suppress per-epoch progress output.
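Beyond the built-in factories, a callback is just a named list of functions matching the signatures above. A minimal hand-written sketch (the `stop_below()` helper is hypothetical, and it assumes `logs` carries the current epoch's metrics such as `logs$train_loss`):

```r
# Hypothetical custom callback: stop training once train_loss
# drops below a threshold. `epoch`, `logs`, and `state` follow
# the signatures documented above; `state` is a mutable
# environment, so assigning state$stop <- TRUE is visible to
# the training loop.
stop_below <- function(threshold = 0.01) {
  list(
    on_epoch_end = function(epoch, logs, state) {
      if (!is.null(logs$train_loss) && logs$train_loss < threshold) {
        state$stop <- TRUE  # signal the training loop to stop
      }
    }
  )
}
```

Such a callback can be passed alongside the built-in factories, e.g. `callbacks = list(stop_below(0.05), ggml_schedule_cosine_decay())`.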
See Also

Other optimization:
ggml_opt_alloc(),
ggml_opt_context_optimizer_type(),
ggml_opt_dataset_data(),
ggml_opt_dataset_free(),
ggml_opt_dataset_get_batch(),
ggml_opt_dataset_init(),
ggml_opt_dataset_labels(),
ggml_opt_dataset_ndata(),
ggml_opt_dataset_shuffle(),
ggml_opt_default_params(),
ggml_opt_epoch(),
ggml_opt_eval(),
ggml_opt_fit(),
ggml_opt_free(),
ggml_opt_get_lr(),
ggml_opt_grad_acc(),
ggml_opt_init(),
ggml_opt_init_for_fit(),
ggml_opt_inputs(),
ggml_opt_labels(),
ggml_opt_loss(),
ggml_opt_loss_type_cross_entropy(),
ggml_opt_loss_type_mean(),
ggml_opt_loss_type_mse(),
ggml_opt_loss_type_sum(),
ggml_opt_ncorrect(),
ggml_opt_optimizer_name(),
ggml_opt_optimizer_type_adamw(),
ggml_opt_optimizer_type_sgd(),
ggml_opt_outputs(),
ggml_opt_pred(),
ggml_opt_prepare_alloc(),
ggml_opt_reset(),
ggml_opt_result_accuracy(),
ggml_opt_result_free(),
ggml_opt_result_init(),
ggml_opt_result_loss(),
ggml_opt_result_ndata(),
ggml_opt_result_pred(),
ggml_opt_result_reset(),
ggml_opt_set_lr(),
ggml_opt_static_graphs()
Examples

if (FALSE) {
  history <- ggml_fit_opt(
    sched, ctx_compute, inputs, outputs, dataset,
    nepoch = 50, val_split = 0.2,
    callbacks = list(
      ggml_callback_early_stopping(monitor = "val_loss", patience = 5),
      ggml_schedule_cosine_decay()
    )
  )
}
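Because the return value is a plain data frame with the columns listed above, the training history can be inspected with base R. A sketch using made-up numbers in place of a real run:

```r
# Hypothetical history in the documented return shape
# (epoch, train_loss, train_accuracy, val_loss, val_accuracy)
history <- data.frame(
  epoch = 1:3,
  train_loss = c(0.9, 0.5, 0.3),
  train_accuracy = c(0.4, 0.7, 0.85),
  val_loss = c(1.0, 0.6, 0.45),
  val_accuracy = c(0.35, 0.65, 0.8)
)

# Pick the epoch with the lowest validation loss
best <- history[which.min(history$val_loss), ]
best$epoch  # 3
```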