
luz (version 0.5.0)

luz_callback_auto_resume: Resume training callback

Description

This callback allows you to resume training a model.

Usage

luz_callback_auto_resume(path = "./state.pt")

Arguments

path

Path to save state files for the model.

Customizing serialization

By default, the model weights, optimizer state, and records are serialized. Callbacks can customize serialization by implementing the state_dict() and load_state_dict() methods. If a callback implements those methods, state_dict() is called at the end of each epoch and load_state_dict() is called when the model is resumed.
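As a rough illustration of the two methods described above, a callback that keeps a counter could expose it for serialization as follows. This is a minimal sketch, not luz's own implementation: the callback name "epoch_counter" and the field epochs_seen are hypothetical, and it assumes that load_state_dict() receives the same list that state_dict() returned.

```r
library(luz)

# Hypothetical callback that counts how many epochs have completed.
cb_epoch_counter <- luz_callback(
  "epoch_counter",
  epochs_seen = 0,
  on_epoch_end = function() {
    self$epochs_seen <- self$epochs_seen + 1
  },
  # Return whatever should be saved alongside the model state.
  state_dict = function() {
    list(epochs_seen = self$epochs_seen)
  },
  # Restore the fields from a previously saved list
  # (assumed to be the value returned by state_dict()).
  load_state_dict = function(state_dict) {
    self$epochs_seen <- state_dict$epochs_seen
  }
)
```

Passing cb_epoch_counter() in the callbacks list together with luz_callback_auto_resume() should then let the counter carry over when training is resumed.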

Details

When this callback is used, the model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script resumes training from the epoch right after the last one that was serialized.
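For re-running the script to pick up where training stopped, the state file has to be at the same path on the next run. A path from tempfile() changes with every R session, so a fixed location is one option. The sketch below assumes a hypothetical "checkpoints" directory in the working directory and creates it up front in case the callback does not create missing directories itself.

```r
library(luz)

# Hypothetical project-local location for the state file.
state_path <- file.path("checkpoints", "state.pt")

# Create the directory beforehand (an assumption: the callback may not do this).
dir.create(dirname(state_path), showWarnings = FALSE, recursive = TRUE)

# Reusing the same path on every run lets a re-run find the saved state.
autoresume <- luz_callback_auto_resume(path = state_path)
```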

See Also

Other luz_callbacks: luz_callback(), luz_callback_csv_logger(), luz_callback_early_stopping(), luz_callback_interrupt(), luz_callback_keep_best_model(), luz_callback_lr_scheduler(), luz_callback_metrics(), luz_callback_mixed_precision(), luz_callback_mixup(), luz_callback_model_checkpoint(), luz_callback_profile(), luz_callback_progress(), luz_callback_resume_from_checkpoint(), luz_callback_train_valid()

Examples

if (torch::torch_is_installed()) {
library(torch)
library(luz)

x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)

model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)


# Simulate a failure at the end of epoch 5 that happens only once.
callback_stop <- luz_callback(
  "interrupt",
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  }
)

path <- tempfile()
autoresume <- luz_callback_auto_resume(path = path)
interrupt <- callback_stop()

# try once and the model fails
try({
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )
})

# model resumes and completes
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume, interrupt),
  verbose = FALSE
)

get_metrics(results)

}
