vaeac model

This function can be applied both in the initialization phase, when we train several initiated vaeac models, and to keep training the best-performing vaeac model for the remaining number of epochs. We are in the former setting when initialization_idx is provided and in the latter when it is NULL. When it is NULL, we save to disk the vaeac models with the best VLB, IWAE, and running IWAE, as well as the models from the epochs given by save_every_nth_epoch.
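A "best" model only needs to be written to disk when its validation criterion beats every earlier epoch. The sketch below uses a hypothetical base-R helper, best_so_far_epochs(), which is not part of shapr. It lists the epochs at which such a checkpoint would be (re)written. It assumes that higher IWAE is better, since IWAE is a lower bound on the log-likelihood, and that ties also count as a new best.

```r
# Hypothetical helper, not part of shapr: returns the epochs at which the
# criterion equals its running maximum, i.e. the epochs where a "best so far"
# checkpoint would be overwritten. Assumes higher values are better and that
# ties count as a new best.
best_so_far_epochs <- function(values) {
  which(values == cummax(values))
}

# Made-up validation IWAE and running IWAE histories for five epochs
val_iwae <- c(-10.2, -9.8, -9.9, -9.5, -9.6)
val_iwae_running <- c(-10.2, -10.0, -9.97, -9.8, -9.7)

best_so_far_epochs(val_iwae)          # epochs 1, 2 and 4
best_so_far_epochs(val_iwae_running)  # every epoch, since the running mean keeps rising
```

The running IWAE typically triggers saves more steadily than the raw IWAE because it smooths out single-epoch fluctuations.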
vaeac_train_model_auxiliary(
vaeac_model,
optimizer,
train_dataloader,
val_dataloader,
val_iwae_n_samples,
running_avg_n_values,
verbose,
cuda,
epochs,
save_every_nth_epoch,
epochs_early_stopping,
epochs_start = 1,
progressr_bar = NULL,
vaeac_save_file_names = NULL,
state_list = NULL,
initialization_idx = NULL,
n_vaeacs_initialize = NULL,
train_vlb = NULL,
val_iwae = NULL,
val_iwae_running = NULL
)
The return value depends on whether we are in the initialization phase. In the initialization phase, the function returns the trained vaeac model. Otherwise, it returns a list with the locations on disk where the vaeac models are stored and the parameters of the model.
vaeac_model A vaeac() object. The vaeac model that this function trains.
optimizer A torch::optimizer() object. See vaeac_get_optimizer().
train_dataloader A torch::dataloader() containing the training data for the vaeac model.
val_dataloader A torch::dataloader() containing the validation data for the vaeac model.
val_iwae_n_samples Positive integer (default is 25). The number of generated samples used to compute the IWAE criterion when validating the vaeac model on the validation data.
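To build intuition for what val_iwae_n_samples controls, recall the standard IWAE estimate with K samples per observation. It is the average over observations of log((1/K) * sum_k w_k), where w_k are importance weights. The base-R sketch below computes this with a numerically stable log-mean-exp. It assumes the log importance weights are already available as an n_obs x K matrix (in practice they would come from the vaeac model). The functions log_mean_exp() and iwae_from_log_weights() are hypothetical.

```r
# Numerically stable log(mean(exp(x))): subtract the maximum before exponentiating
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# Hypothetical helper: IWAE estimate from an n_obs x K matrix of log importance
# weights, where the number of columns K plays the role of val_iwae_n_samples
iwae_from_log_weights <- function(log_w) {
  mean(apply(log_w, 1, log_mean_exp))
}

# Toy example: one observation with K = 2 weights (1 and 3) gives log(2)
iwae_from_log_weights(matrix(log(c(1, 3)), nrow = 1))
```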
running_avg_n_values Positive integer (default is 5). The number of previous IWAE values to include when we compute the running means of the IWAE criterion.
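The running IWAE smooths the epoch-to-epoch noise in the validation IWAE. The hypothetical helper below is not shapr's implementation. At each epoch it takes the mean of the last running_avg_n_values values. It assumes that in early epochs, when fewer values exist, it averages over whatever is available.

```r
# Hypothetical helper: running mean of (at most) the last n values at each position
running_mean_last_n <- function(x, n) {
  vapply(seq_along(x), function(i) mean(x[max(1, i - n + 1):i]), numeric(1))
}

val_iwae <- c(-12, -11, -10, -9, -8)
running_mean_last_n(val_iwae, n = 3)  # c(-12, -11.5, -11, -10, -9)
```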
verbose String vector or NULL. Specifies the verbosity (printout detail level) through one or more of the strings "basic", "progress", "convergence", "shapley" and "vS_details".
"basic" (default) displays basic information about the computation being performed, in addition to some messages about parameters being set or checks being unavailable due to specific input.
"progress" displays information about where in the calculation process the function currently is.
"convergence" displays information on how close to convergence the Shapley value estimates are (only when iterative = TRUE).
"shapley" displays intermediate Shapley value estimates and standard deviations (only when iterative = TRUE) and the final estimates.
"vS_details" displays information about the v_S estimates. This is most relevant for approach %in% c("regression_separate", "regression_surrogate", "vaeac").
NULL means no printout.
Note that any combination of the five strings can be used. E.g., verbose = c("basic", "vS_details") will display basic information + details about the v(S)-estimation process.
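Because verbose is just a character vector, checking whether a level is switched on reduces to a membership test. The sketch below uses two hypothetical helpers, check_verbose() and is_active(), which are not shapr functions. One validates the input and the other queries a level. Note how NULL naturally switches every level off.

```r
# Allowed verbosity levels, as listed for the verbose argument
allowed_verbose <- c("basic", "progress", "convergence", "shapley", "vS_details")

# Hypothetical validator: accepts NULL (no printout) or any subset of the allowed strings
check_verbose <- function(verbose) {
  if (is.null(verbose)) return(invisible(TRUE))
  if (!is.character(verbose) || !all(verbose %in% allowed_verbose)) {
    stop("`verbose` must be NULL or a subset of: ",
         paste(allowed_verbose, collapse = ", "))
  }
  invisible(TRUE)
}

# Hypothetical helper: is a given level switched on? Always FALSE when verbose is NULL
is_active <- function(level, verbose) level %in% verbose

verbose <- c("basic", "vS_details")
check_verbose(verbose)
is_active("vS_details", verbose)  # TRUE
is_active("progress", verbose)    # FALSE
```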
cuda Logical (default is FALSE). If TRUE, then the vaeac model will be trained using cuda/GPU. If torch::cuda_is_available() is FALSE, then we fall back to the CPU. If FALSE, we use the CPU. Using a GPU for smaller tabular datasets often does not improve efficiency. See vignette("installation", package = "torch") for help with enabling GPU training (only Linux and Windows).
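The fallback rule amounts to a single condition. The sketch below uses a hypothetical helper, choose_device(), which is not part of shapr. It takes CUDA availability as an argument so the example runs without torch installed.

```r
# Hypothetical helper, not part of shapr: choose the device name from the user's
# `cuda` flag and whether CUDA is actually available; falls back to the CPU otherwise.
choose_device <- function(cuda, cuda_available) {
  if (isTRUE(cuda) && isTRUE(cuda_available)) "cuda" else "cpu"
}

choose_device(cuda = TRUE, cuda_available = FALSE)  # "cpu" (fallback)
choose_device(cuda = TRUE, cuda_available = TRUE)   # "cuda"
choose_device(cuda = FALSE, cuda_available = TRUE)  # "cpu"
```

With the torch package installed, the second argument would be torch::cuda_is_available(), and the resulting string could be passed to torch::torch_device().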
epochs Positive integer (default is 100). The number of epochs to train the final vaeac model. This includes epochs_initiation_phase, where the default is 2.
save_every_nth_epoch Positive integer (default is NULL). If provided, the vaeac model is saved after every save_every_nth_epoch-th epoch.
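The periodic checkpoints are simply the multiples of save_every_nth_epoch up to epochs. The hypothetical helper below, which is not part of shapr, lists them and returns none when the argument is NULL.

```r
# Hypothetical helper: the epochs after which a periodic checkpoint is written.
# Returns an empty integer vector when save_every_nth_epoch is NULL.
periodic_save_epochs <- function(epochs, save_every_nth_epoch = NULL) {
  if (is.null(save_every_nth_epoch)) return(integer(0))
  which(seq_len(epochs) %% save_every_nth_epoch == 0)
}

periodic_save_epochs(epochs = 100, save_every_nth_epoch = 25)  # 25 50 75 100
periodic_save_epochs(epochs = 100)                             # integer(0)
```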
epochs_early_stopping Positive integer (default is NULL). The training stops if there has been no improvement in the validation IWAE for epochs_early_stopping epochs. If the user wants the training process to be solely based on this training criterion, then epochs in explain() should be set to a large number. If NULL, then shapr will internally set epochs_early_stopping = vaeac.epochs such that early stopping does not occur.
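A patience-based early-stopping rule can be sketched in a few lines. The hypothetical helper below is not shapr's code. It assumes that higher validation IWAE is better and that only a strict improvement resets the patience counter.

```r
# Hypothetical helper, not part of shapr: returns the epoch at which training
# would stop, or NA if it runs to completion. Assumes higher validation IWAE is
# better and that only a strict improvement resets the patience counter.
early_stopping_epoch <- function(val_iwae, epochs_early_stopping) {
  best <- -Inf
  best_epoch <- 0L
  for (epoch in seq_along(val_iwae)) {
    if (val_iwae[epoch] > best) {
      best <- val_iwae[epoch]
      best_epoch <- epoch
    }
    # Stop once `epochs_early_stopping` epochs have passed without improvement
    if (epoch - best_epoch >= epochs_early_stopping) return(epoch)
  }
  NA_integer_
}

val_iwae <- c(-10, -9, -9.5, -9.2, -9.4, -9.1)
early_stopping_epoch(val_iwae, epochs_early_stopping = 3)                 # 5
early_stopping_epoch(val_iwae, epochs_early_stopping = length(val_iwae))  # NA
```

Under this rule, setting epochs_early_stopping equal to the number of epochs (as shapr does when it is NULL) means the stopping condition can never be met. The first epoch always counts as an improvement, so at most epochs - 1 epochs can pass without one.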
epochs_start Positive integer (default is 1). The epoch at which the training starts.
progressr_bar A progressr::progressor() object (default is NULL) to keep track of progress.
vaeac_save_file_names Array of strings containing the save file names for the vaeac model.
state_list Named list containing the objects returned from vaeac_get_full_state_list().
initialization_idx Positive integer (default is NULL). The index of the current vaeac model in the initialization phase.
n_vaeacs_initialize Positive integer (default is 4). The number of different vaeac models to initiate at the start. The best-performing one after epochs_initiation_phase epochs (default is 2) is picked, and training continues with that one.
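The hand-over from the initialization phase to continued training can be sketched as follows. The hypothetical helper plan_continuation() is not shapr's code and rests on two assumptions. First, "best-performing" means the highest final validation IWAE. Second, continued training starts at epoch epochs_initiation_phase + 1, since epochs includes epochs_initiation_phase.

```r
# Hypothetical sketch, not shapr's code: after each of the n_vaeacs_initialize
# models has been trained for epochs_initiation_phase epochs, keep the one with
# the highest final validation IWAE and continue training it from the next epoch.
plan_continuation <- function(final_val_iwae, epochs, epochs_initiation_phase) {
  best_idx <- which.max(final_val_iwae)
  n_remaining <- epochs - epochs_initiation_phase
  list(
    best_idx = best_idx,
    epochs_start = epochs_initiation_phase + 1,
    remaining_epochs = epochs_initiation_phase + seq_len(n_remaining)
  )
}

# Made-up final validation IWAE values for four initialized models
plan <- plan_continuation(c(-10.4, -9.7, -10.1, -9.9), epochs = 100, epochs_initiation_phase = 2)
plan$best_idx                  # 2
plan$epochs_start              # 3
length(plan$remaining_epochs)  # 98
```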
train_vlb A one-dimensional torch::torch_tensor() (default is NULL) containing previous values for the training VLB.
val_iwae A one-dimensional torch::torch_tensor() (default is NULL) containing previous values for the validation IWAE.
val_iwae_running A one-dimensional torch::torch_tensor() (default is NULL) containing previous values for the running validation IWAE.
Lars Henry Berge Olsen