
autotab (version 0.1.3)

VAE_train: Train an AutoTab VAE on mixed-type tabular data

Description

Runs the full AutoTab training loop (encoder + decoder + latent space), with optional Beta-annealing (linear or cyclical) of the KL term, optional Gumbel-softmax temperature warm-up for categorical outputs, and a choice of latent prior (single Gaussian or mixture of Gaussians).

Usage

VAE_train(
  data,
  encoder_info,
  decoder_info,
  Lip_en,
  pi_enc = 1,
  lip_dec,
  pi_dec = 1,
  latent_dim,
  epoch,
  beta,
  kl_warm = FALSE,
  kl_cyclical = FALSE,
  n_cycles,
  ratio,
  beta_epoch = 15,
  temperature,
  temp_warm = FALSE,
  temp_epoch,
  batchsize,
  wait,
  min_delta = 0.001,
  lr,
  max_std = 10,
  min_val = 0.001,
  weighted = 0,
  recon_weights,
  seperate = 0,
  prior = "single_gaussian",
  K = 3,
  learnable_mog = FALSE,
  mog_means = NULL,
  mog_log_vars = NULL,
  mog_weights = NULL
)

Value

A list with:

  • trained_model — the compiled Keras model (encoder→decoder) with KL and recon losses added.

  • loss_history — numeric vector of per-epoch total loss (as tracked during training).

Arguments

data

Matrix/data.frame. Preprocessed training data (columns match the order in feat_dist).

encoder_info, decoder_info

Lists describing layer stacks. Each element is e.g. list("dense", units, "activation", L2_flag, L2_lambda, BN_flag, BN_momentum, BN_learn) or list("dropout", rate).
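For illustration, a small encoder/decoder specification could be written as below. Layer sizes and flag values here are arbitrary; each "dense" element follows the list("dense", units, activation, L2_flag, L2_lambda, BN_flag, BN_momentum, BN_learn) convention above:

```r
# Hypothetical layer specs for VAE_train(); values are illustrative only.
encoder_info <- list(
  list("dense", 64L, "relu", 1, 1e-4, 1, 0.99, TRUE),  # dense + L2 penalty + batch norm
  list("dropout", 0.2),                                # dropout layer with rate 0.2
  list("dense", 32L, "relu", 0, 0, 0, 0, FALSE)        # plain dense layer
)
decoder_info <- list(
  list("dense", 32L, "relu", 0, 0, 0, 0, FALSE),
  list("dense", 64L, "relu", 0, 0, 0, 0, FALSE)
)
```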

Lip_en, lip_dec

Integer (0/1). Whether to apply spectral normalization (Lipschitz constraint) in the encoder/decoder.

pi_enc, pi_dec

Integer. Power-iteration counts for spectral normalization.

latent_dim

Integer. Latent dimensionality.

epoch

Integer. Max training epochs.

beta

Numeric. Beta-VAE weight on the KL term in the ELBO.

kl_warm

Logical. Enable Beta-annealing.

kl_cyclical

Logical. Enable cyclical Beta-annealing (requires kl_warm = TRUE).

n_cycles

Integer. Number of cycles when kl_cyclical = TRUE.

ratio

Numeric in [0, 1]. Fraction of each cycle used for warm-up (linear rise from 0 to Beta).

beta_epoch

Integer. Warm-up length (epochs) for linear Beta-annealing; when kl_cyclical = TRUE, the cycle length is (beta_epoch / n_cycles).
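As a sketch of how beta_epoch, n_cycles, and ratio interact under the cyclical schedule described above (this is an illustration of the documented behaviour, not the package's internal code):

```r
# Illustrative cyclical Beta schedule: within each cycle, beta rises
# linearly from 0 to its target over the first `ratio` fraction, then holds.
beta <- 1; beta_epoch <- 15; n_cycles <- 3; ratio <- 0.5
cycle_len <- beta_epoch / n_cycles          # cycle length in epochs (here 5)
beta_at <- function(ep) {
  pos <- (ep - 1) %% cycle_len              # position within the current cycle
  min(1, pos / (cycle_len * ratio)) * beta  # linear rise 0 -> beta, then hold
}
sapply(1:15, beta_at)                       # schedule over 15 epochs
```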

temperature

Numeric. Gumbel-softmax temperature (used for categorical heads).

temp_warm

Logical. Enable temperature warm-up.

temp_epoch

Integer. Warm-up length (epochs) for temperature when temp_warm = TRUE.

batchsize

Integer. Mini-batch size.

wait

Integer. Early-stopping patience (epochs) on validation reconstruction loss.

min_delta

Numeric. Minimum improvement to reset patience (early stopping).

lr

Numeric. Learning rate (Adam).

max_std, min_val

Numerics. Decoder constraints for Gaussian heads (maximum standard deviation; minimum value for the variance surrogate).

weighted

Integer (0/1). If 1, weight reconstruction terms by type.

recon_weights

Numeric length-3. Weights for (continuous, binary, categorical); required when weighted = 1.

seperate

Integer (0/1). If 1, logs per-group reconstruction losses as metrics (cont_loss, bin_loss, cat_loss) in addition to total recon_loss.

prior

Character. "single_gaussian" or "mixture_gaussian".

K

Integer. Number of mixture components when prior = "mixture_gaussian".

learnable_mog

Logical. If TRUE, MoG prior parameters are trainable.

mog_means, mog_log_vars, mog_weights

Optional initial values for the MoG prior; ignored unless prior = "mixture_gaussian". When learnable_mog = FALSE, all three must be provided.
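A fixed (non-learnable) mixture prior therefore needs all three initial values supplied. For example, for K = 3 components in a 2-D latent space (values are arbitrary, and the K x latent_dim matrix shapes are an assumption; check the package source for the exact shapes expected):

```r
# Illustrative fixed MoG prior parameters for prior = "mixture_gaussian".
K <- 3; latent_dim <- 2
mog_means    <- matrix(c(-2, 0,   # component 1 mean
                          0, 0,   # component 2 mean
                          2, 0),  # component 3 mean
                       nrow = K, byrow = TRUE)
mog_log_vars <- matrix(0, nrow = K, ncol = latent_dim)  # unit variances
mog_weights  <- rep(1 / K, K)                           # uniform mixing weights
```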

Details

Prerequisite: call set_feat_dist() once before training to register the per-feature distributions and parameter counts (see extracting_distribution() and feat_reorder()).

Metrics exposed during training: loss, recon_loss, and kl_loss; when seperate = 1, also cont_loss, bin_loss, and cat_loss; and beta and temperature when annealed.

Early stopping: monitored on val_recon_loss with patience = wait.

Reproducibility: set seeds via your own workflow or the helper reset_seeds().
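Putting it together, a minimal training call might look like the following. All hyper-parameters and layer specs are illustrative, train_mat is a placeholder for your preprocessed data, and set_feat_dist() is assumed to have been called beforehand; arguments tied to disabled options (n_cycles, ratio, temp_epoch, recon_weights) are omitted on the assumption they are only required when the corresponding option is enabled:

```r
library(autotab)

# Prerequisite: set_feat_dist() must already have registered the
# per-feature distributions for the columns of train_mat.
fit <- VAE_train(
  data         = train_mat,  # preprocessed matrix, columns in feat_dist order
  encoder_info = list(list("dense", 64L, "relu", 0, 0, 0, 0, FALSE)),
  decoder_info = list(list("dense", 64L, "relu", 0, 0, 0, 0, FALSE)),
  Lip_en  = 0, lip_dec = 0,  # no spectral normalization
  latent_dim  = 8L,
  epoch       = 100L,
  beta        = 1,
  kl_warm     = TRUE,        # linear Beta-annealing over beta_epoch epochs
  beta_epoch  = 15L,
  temperature = 0.5,         # Gumbel-softmax temperature for categorical heads
  batchsize   = 128L,
  wait        = 10L,         # early-stopping patience on val_recon_loss
  lr          = 1e-3
)

plot(fit$loss_history, type = "l")  # per-epoch total loss
```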

Expected Warning: When running AutoTab, the user will receive the following warning from TensorFlow: "WARNING:tensorflow:The following Variables were used in a Lambda layer's call (tf.math.multiply_3), but are not present in its tracked objects: <tf.Variable 'beta:0' shape=() dtype=float32>. This is a strong indication that the Lambda layer should be rewritten as a subclassed Layer."

This is merely a warning and does not affect AutoTab's computations. It occurs because TensorFlow does not see beta (the weight on the regularization part of the ELBO) until the first training iteration, when the loss is first computed, so beta is not an internally tracked object. However, beta is tracked and updated outside the model graph, which can be verified in the KL-loss plots and in the training printout in the R console.

See Also

set_feat_dist(), extracting_distribution(), feat_reorder(), Encoder_weights(), encoder_latent(), Decoder_weights(), Latent_sample()