# Example: resuming an interrupted luz training run with
# luz_callback_auto_resume(). Skipped entirely when torch is not installed.
if (torch::torch_is_installed()) {
  library(torch)
  library(luz)

  # Random inputs (1000 x 10) and random targets (1000 x 1).
  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  # A single linear layer (10 inputs -> 1 output) trained with SGD
  # (learning rate 0.01) against a mean-squared-error loss.
  model <- nn_linear %>%
    setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # Callback that simulates a one-off failure: it raises an error when
  # epoch 5 ends, but only the first time. The `failed` flag lives on the
  # callback instance, so reusing that instance in a later fit() call does
  # not trigger the error again.
  callback_stop <- luz_callback(
    "interrupt",
    failed = FALSE,
    on_epoch_end = function() {
      # Nothing to do on any other epoch, or once the failure has fired.
      if (ctx$epoch != 5 || self$failed) {
        return(invisible(NULL))
      }
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  )

  # Temporary path handed to the auto-resume callback.
  path <- tempfile()
  autoresume <- luz_callback_auto_resume(path = path)
  interrupt <- callback_stop()

  # First attempt: the simulated error aborts training; try() keeps the
  # script running so the second attempt can follow.
  try(
    results <- fit(
      model,
      list(x, y),
      callbacks = list(autoresume, interrupt),
      verbose = FALSE
    )
  )

  # Second attempt with the same callback objects: training resumes and
  # runs to completion.
  results <- fit(
    model,
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )

  get_metrics(results)
}
# Run the code above in your browser using DataLab