Runs the full AutoTab training loop (encoder + decoder + latent space), with optional Beta-annealing (linear or cyclical), optional Gumbel-softmax temperature warming for categorical outputs, and options for the prior.
VAE_train(
data,
encoder_info,
decoder_info,
Lip_en,
pi_enc = 1,
lip_dec,
pi_dec = 1,
latent_dim,
epoch,
beta,
kl_warm = FALSE,
kl_cyclical = FALSE,
n_cycles,
ratio,
beta_epoch = 15,
temperature,
temp_warm = FALSE,
temp_epoch,
batchsize,
wait,
min_delta = 0.001,
lr,
max_std = 10,
min_val = 0.001,
weighted = 0,
recon_weights,
seperate = 0,
prior = "single_gaussian",
K = 3,
learnable_mog = FALSE,
mog_means = NULL,
mog_log_vars = NULL,
mog_weights = NULL
)

A list with:
trained_model — the compiled Keras model (encoder→decoder) with KL and recon losses added.
loss_history — numeric vector of per-epoch total loss (as tracked during training).
Matrix/data.frame. Preprocessed training data (columns match
the order in feat_dist).
Lists describing layer stacks. Each element
is e.g. list("dense", units, "activation", L2_flag, L2_lambda, BN_flag, BN_momentum, BN_learn) or
list("dropout", rate).
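A hypothetical specification built from the element formats above; the exact field order is inferred from that description and the objects (`encoder_info`, `decoder_info`) are illustrative placeholders:

```r
# Hypothetical layer stacks; field order is an assumption from the
# description above, not verified against the package internals.
encoder_info <- list(
  # dense: units, activation, L2 flag, L2 lambda, BN flag, BN momentum, BN learnable
  list("dense", 64L, "relu", 1, 0.001, 1, 0.99, TRUE),
  list("dropout", 0.2),
  list("dense", 32L, "relu", 0, 0, 0, 0.99, TRUE)
)
decoder_info <- rev(encoder_info)  # mirror the encoder, purely for illustration
```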
Integer (0/1). Use spectral normalization (Lipschitz) in encoder/decoder.
Integer. Power-iteration counts for spectral normalization.
Integer. Latent dimensionality.
Integer. Max training epochs.
Numeric. Beta-VAE weight on the KL term in the ELBO.
Logical. Enable Beta-annealing.
Logical. Enable cyclical Beta-annealing (requires kl_warm = TRUE).
Integer. Number of cycles when kl_cyclical = TRUE.
Numeric in [0, 1]. Fraction of each cycle used for warm-up (rise from 0 to Beta).
Integer. Warm-up length (epochs) for linear Beta-annealing; when
kl_cyclical = TRUE, the cycle length is (beta_epoch / n_cycles).
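The two annealing modes described above can be sketched as a per-epoch schedule. This is an illustrative reconstruction from the argument descriptions, not the package's internal code:

```r
# Sketch of the Beta schedules described above (assumption: 0-indexed epochs).
# Linear: rise from 0 to beta over beta_epoch epochs, then hold.
# Cyclical: cycles of length beta_epoch / n_cycles; beta rises over the
# first `ratio` fraction of each cycle, then holds at its target.
beta_schedule <- function(epoch, beta, kl_cyclical = FALSE,
                          beta_epoch = 15, n_cycles = 4, ratio = 0.5) {
  if (!kl_cyclical) {
    min(beta, beta * epoch / beta_epoch)
  } else {
    cycle_len <- beta_epoch / n_cycles
    pos <- (epoch %% cycle_len) / cycle_len   # position within current cycle
    if (pos < ratio) beta * pos / ratio else beta
  }
}
```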
Numeric. Gumbel-softmax temperature (used for categorical heads).
Logical. Enable temperature warm-up.
Integer. Warm-up length (epochs) for temperature when temp_warm = TRUE.
Integer. Mini-batch size.
Integer. Early-stopping patience (epochs) on validation reconstruction loss.
Numeric. Minimum improvement to reset patience (early stopping).
Numeric. Learning rate (Adam).
Numerics. Decoder constraints for Gaussian heads (max SD; minimum variance surrogate).
Integer (0/1). If 1, weight reconstruction terms by type.
Numeric length-3. Weights for (continuous, binary, categorical);
required when weighted = 1.
Integer (0/1). If 1, logs per-group reconstruction losses as metrics
(cont_loss, bin_loss, cat_loss) in addition to total recon_loss.
Character. "single_gaussian" or "mixture_gaussian".
Integer. Number of mixture components when prior = "mixture_gaussian".
Logical. If TRUE, MoG prior parameters are trainable.
Optional initial values for the MoG prior
(ignored unless prior = "mixture_gaussian"; when learnable_mog = FALSE they must be provided).
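A minimal sketch of initializing the MoG prior parameters when learnable_mog = FALSE; the shapes (K rows, one column per latent dimension) are an assumption based on the descriptions above:

```r
# Illustrative MoG prior initialization for K = 3 components in a
# 2-dimensional latent space. Shapes are assumed, not verified.
K <- 3; latent_dim <- 2
mog_means    <- matrix(rnorm(K * latent_dim), nrow = K)  # component means
mog_log_vars <- matrix(0, nrow = K, ncol = latent_dim)   # log(1) = unit variances
mog_weights  <- rep(1 / K, K)                            # uniform mixing weights
```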
Prerequisite: call set_feat_dist() once before training to register the
per-feature distributions and parameter counts (see extracting_distribution()
and feat_reorder()).
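A hedged end-to-end sketch following the prerequisite above; all argument values and the train_mat, feat_dist, encoder_info, and decoder_info objects are illustrative placeholders, not package defaults:

```r
# Illustrative call only; assumes the AutoTab package is attached and that
# train_mat is preprocessed data whose columns already match feat_dist.
set_feat_dist(feat_dist)   # register per-feature distributions first

fit <- VAE_train(
  data = train_mat,
  encoder_info = encoder_info,
  decoder_info = decoder_info,
  Lip_en = 0, lip_dec = 0,
  latent_dim = 8,
  epoch = 100,
  beta = 1,
  kl_warm = TRUE, beta_epoch = 15,
  temperature = 0.5,
  batchsize = 64,
  wait = 10,
  lr = 1e-3
)

fit$trained_model                      # compiled Keras model
plot(fit$loss_history, type = "l")     # per-epoch total loss
```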
Metrics exposed during training: loss, recon_loss, and kl_loss; when
seperate = 1, also cont_loss, bin_loss, and cat_loss; and beta and
temperature when annealed.
Early stopping: monitored on val_recon_loss with patience = wait.
Reproducibility: set seeds via your own workflow or the helper reset_seeds().
Expected warning: when running AutoTab, the user will receive the following warning from TensorFlow: "WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.math.multiply_3), but are not present in its tracked objects: <tf.Variable 'beta:0' shape=() dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer."
This is merely a warning and should not affect AutoTab's computations. It occurs because TensorFlow does not see beta (the weight on the regularization part of the ELBO) until the first training iteration triggers the first loss computation, so beta is not an internally tracked object. It is, however, tracked and updated outside the model graph, as can be seen in the KL loss plots and in the training printout in the R console.
set_feat_dist(), extracting_distribution(), feat_reorder(),
Encoder_weights(), encoder_latent(), Decoder_weights(), Latent_sample()