vaeac model
This function can be applied both in the initialization phase, when we train several initiated vaeac models, and
to keep training the best-performing vaeac model for the remaining number of epochs. We are in the former setting
when initialization_idx is provided and in the latter when it is NULL. When it is NULL, we save to disk the vaeac models
with the best VLB, IWAE, and running IWAE, as well as the models at the epochs given by save_every_nth_epoch.
vaeac_train_model_auxiliary(
vaeac_model,
optimizer,
train_dataloader,
val_dataloader,
val_iwae_n_samples,
running_avg_n_values,
verbose,
cuda,
epochs,
save_every_nth_epoch,
epochs_early_stopping,
epochs_start = 1,
progressr_bar = NULL,
vaeac_save_file_names = NULL,
state_list = NULL,
initialization_idx = NULL,
n_vaeacs_initialize = NULL,
train_vlb = NULL,
val_iwae = NULL,
val_iwae_running = NULL
)
Depending on whether we are in the initialization phase, the function returns either the trained vaeac model or
a list containing where the vaeac models are stored on disk and the parameters of the model.
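The two modes described above can be sketched in base R as follows. This is a hypothetical illustration, not shapr's implementation. The helper train_phase_result(), the toy model list, and the default file names are all made up. The sketch only shows how a NULL versus non-NULL initialization_idx changes what is returned.

```r
# Hypothetical sketch (not shapr's code): the phase is decided by whether
# initialization_idx is NULL, and each phase returns a different kind of object.
train_phase_result <- function(model, initialization_idx = NULL,
                               save_file_names = c("best", "best_running", "last")) {
  if (!is.null(initialization_idx)) {
    # Initialization phase: return the trained model so the caller can
    # compare it with the other initialized models.
    return(model)
  }
  # Continued training: return where the models were saved (hypothetical
  # names here) together with the model parameters.
  list(save_file_names = save_file_names, parameters = model$parameters)
}

toy_model <- list(parameters = list(depth = 3, width = 32))
init_result <- train_phase_result(toy_model, initialization_idx = 1)
final_result <- train_phase_result(toy_model)
```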
vaeac_model A vaeac() object. The vaeac model that this function trains.
optimizer A torch::optimizer() object. See vaeac_get_optimizer().
train_dataloader A torch::dataloader() containing the training data for the vaeac model.
val_dataloader A torch::dataloader() containing the validation data for the vaeac model.
val_iwae_n_samples Positive integer (default is 25). The number of generated samples used
to compute the IWAE criterion when validating the vaeac model on the validation data.
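For reference, the generic importance-weighted (IWAE) estimate averages K = val_iwae_n_samples importance weights inside the logarithm, log((1/K) sum_k w_k), where log w_k = log p(x, z_k) - log q(z_k | x). The base R sketch below is not shapr's implementation. It assumes the log-weights are already available as a matrix with one row per validation observation, and it computes the estimate with a numerically stable log-mean-exp. The helper names and the simulated log-weights are hypothetical.

```r
# Numerically stable log(mean(exp(x))): subtract the maximum before exponentiating.
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# log_w: n_obs x K matrix of log importance weights, log p(x, z_k) - log q(z_k | x),
# with K = val_iwae_n_samples draws per observation (row).
iwae_estimate <- function(log_w) {
  mean(apply(log_w, 1, log_mean_exp))
}

set.seed(1)
val_iwae_n_samples <- 25
# Simulated log-weights, only to exercise the function.
log_w <- matrix(rnorm(10 * val_iwae_n_samples, mean = -3), nrow = 10)
iwae_estimate(log_w)
```

By Jensen's inequality, the estimate is never smaller than the plain average of the log-weights. This is why using more samples tightens the bound.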
running_avg_n_values Positive integer (default is 5).
The number of previous IWAE values to include
when we compute the running means of the IWAE criterion.
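Below is a minimal sketch of such a running mean over a vector of per-epoch validation IWAE values. The helper running_iwae() is hypothetical. The sketch also assumes that during the first epochs, when fewer than running_avg_n_values values exist, all available values are averaged.

```r
# Hypothetical helper: mean of the most recent `running_avg_n_values`
# IWAE values at every epoch (fewer values are used in the first epochs).
running_iwae <- function(val_iwae, running_avg_n_values = 5) {
  vapply(seq_along(val_iwae), function(i) {
    mean(val_iwae[max(1, i - running_avg_n_values + 1):i])
  }, numeric(1))
}

running_iwae(c(-10, -8, -7, -6.5, -6.2, -6.1), running_avg_n_values = 3)
```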
verbose String vector or NULL.
Controls verbosity (printout detail level) via one or more of "basic", "progress",
"convergence", "shapley" and "vS_details".
"basic" (default) displays basic information about the computation and messages about parameters/checks.
"progress" displays where in the calculation process the function currently is.
"convergence" displays how close the Shapley value estimates are to convergence
(only when iterative = TRUE).
"shapley" displays intermediate Shapley value estimates and standard deviations (only when iterative = TRUE),
and the final estimates.
"vS_details" displays information about the v(S) estimates,
most relevant for approach %in% c("regression_separate", "regression_surrogate", "vaeac").
NULL means no printout.
Any combination can be used, e.g., verbose = c("basic", "vS_details").
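Checking whether a given printout level is active reduces to a membership test. The helper verbose_has() below is hypothetical and is not part of shapr. It only illustrates that %in% returns FALSE for every level when verbose is NULL.

```r
valid_levels <- c("basic", "progress", "convergence", "shapley", "vS_details")

# Hypothetical helper: is `level` switched on in `verbose`?
verbose_has <- function(verbose, level) {
  # Reject unknown levels; NULL passes trivially and means no printout.
  stopifnot(all(verbose %in% valid_levels))
  level %in% verbose
}

verbose <- c("basic", "vS_details")
verbose_has(verbose, "vS_details")  # TRUE
verbose_has(verbose, "progress")    # FALSE
verbose_has(NULL, "basic")          # FALSE
```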
cuda Logical (default is FALSE). If TRUE, then the vaeac model will be trained using cuda/GPU.
If torch::cuda_is_available() is FALSE, we fall back to using the CPU. Using a GPU
for smaller tabular datasets often does not improve efficiency.
See vignette("installation", package = "torch") for help enabling GPU training (only Linux and Windows).
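The CPU fallback comes down to a single condition. In the sketch below, choose_device() is hypothetical. The logical cuda_available stands in for torch::cuda_is_available() so that the example runs without torch installed. With torch, the resulting string could be passed to torch::torch_device().

```r
# Hypothetical helper: use the GPU only when it is both requested and available.
choose_device <- function(cuda, cuda_available) {
  if (isTRUE(cuda) && isTRUE(cuda_available)) "cuda" else "cpu"
}

choose_device(cuda = TRUE, cuda_available = FALSE)  # "cpu" (fallback)
choose_device(cuda = TRUE, cuda_available = TRUE)   # "cuda"
choose_device(cuda = FALSE, cuda_available = TRUE)  # "cpu"
```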
epochs Positive integer (default is 100). The number of epochs to train the final vaeac model.
This includes the epochs_initiation_phase epochs (default is 2).
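The arithmetic with the defaults is shown below. This sketch assumes that continued training of the selected model starts right after the initialization phase, that is, epochs_start = epochs_initiation_phase + 1.

```r
epochs <- 100                 # total number of epochs (default)
epochs_initiation_phase <- 2  # epochs spent on each initialized model (default)

# Assumption: continued training resumes right after the initialization phase.
epochs_start <- epochs_initiation_phase + 1
remaining_epochs <- seq(epochs_start, epochs)
length(remaining_epochs)  # 98
```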
save_every_nth_epoch Positive integer (default is NULL). If provided, the vaeac model is saved
after every save_every_nth_epoch-th epoch.
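A modulo test lists the epochs that trigger a save. The helper checkpoint_epochs() is hypothetical. A training loop would run the equivalent check, epoch %% save_every_nth_epoch == 0, once per epoch.

```r
# Hypothetical helper: epochs (out of `epochs`) after which the model is saved.
checkpoint_epochs <- function(epochs, save_every_nth_epoch = NULL) {
  if (is.null(save_every_nth_epoch)) return(integer(0))  # NULL: no periodic saves
  which(seq_len(epochs) %% save_every_nth_epoch == 0)
}

checkpoint_epochs(100, 25)  # 25 50 75 100
checkpoint_epochs(100)      # integer(0)
```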
epochs_early_stopping Positive integer (default is NULL). Training stops if there has been no
improvement in the validation IWAE for epochs_early_stopping epochs. If the user wants the training process
to be based solely on this criterion, then epochs in explain() should be set to a large
number. If NULL, then shapr internally sets epochs_early_stopping = vaeac.epochs so that early
stopping does not occur.
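The early-stopping rule can be sketched as a patience counter. In this hypothetical helper, an improvement means a strictly higher validation IWAE than the best value seen so far. That definition is an assumption, motivated by the IWAE being a lower bound on the log-likelihood. If epochs_early_stopping equals the total number of epochs, the counter can never reach the limit, which matches the NULL default described above.

```r
# Hypothetical helper: epoch at which training would stop, or NA if it never does.
early_stopping_epoch <- function(val_iwae, epochs_early_stopping) {
  best <- -Inf
  since_best <- 0
  for (epoch in seq_along(val_iwae)) {
    if (val_iwae[epoch] > best) {
      best <- val_iwae[epoch]  # new best validation IWAE: reset the counter
      since_best <- 0
    } else {
      since_best <- since_best + 1
    }
    if (since_best >= epochs_early_stopping) return(epoch)
  }
  NA_integer_
}

early_stopping_epoch(c(-9, -8, -7.5, -7.6, -7.7, -7.8), epochs_early_stopping = 3)  # 6
```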
epochs_start Positive integer (default is 1). The epoch at which training starts.
progressr_bar A progressr::progressor() object (default is NULL) to keep track of progress.
vaeac_save_file_names Array of strings containing the save file names for the vaeac model.
state_list Named list containing the objects returned from vaeac_get_full_state_list().
initialization_idx Positive integer (default is NULL). The index
of the current vaeac model in the initialization phase.
n_vaeacs_initialize Positive integer (default is 4). The number of different vaeac models to initiate
at the start. The best-performing one after epochs_initiation_phase
epochs (default is 2) is picked, and its training continues.
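You can sketch picking the best initialized model with which.max(). The validation IWAE values below are made-up numbers for illustration. Using the validation IWAE as the selection criterion is an assumption of this sketch.

```r
n_vaeacs_initialize <- 4

# Made-up validation IWAE of each initialized model after the initialization phase.
val_iwae_after_init <- c(-7.9, -7.2, -7.6, -8.3)
stopifnot(length(val_iwae_after_init) == n_vaeacs_initialize)

# Higher IWAE is better, so keep training the model with the largest value.
best_initialization_idx <- which.max(val_iwae_after_init)
best_initialization_idx  # 2
```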
train_vlb A one-dimensional torch::torch_tensor() (default is NULL)
containing previous values of the training VLB.
val_iwae A one-dimensional torch::torch_tensor() (default is NULL)
containing previous values of the validation IWAE.
val_iwae_running A one-dimensional torch::torch_tensor() (default is NULL)
containing previous values of the running validation IWAE.
Lars Henry Berge Olsen