
autotab (version 0.1.1)

Decoder_weights: Extract decoder-only weights from a trained Keras model

Description

Extracts only the decoder weights from keras::get_weights(trained_model), skipping the encoder parameters and, when a learnable mixture-of-Gaussians (MoG) prior is used, its three trailing trainable tensors (means, log-variances, and weight logits).

Usage

Decoder_weights(
  encoder_layers,
  trained_model,
  lip_enc,
  pi_enc,
  prior_learn,
  BNenc_layers,
  learn_BN
)

Value

A list() of decoder weight tensors in order, suitable for set_weights().
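
Because the returned list is meant to be passed to set_weights(), it can help to confirm that its shapes line up with the target model first. The following is a minimal base R sketch of such a check, not part of AutoTab: shapes_match() is a hypothetical helper, and the two mock lists stand in for the output of Decoder_weights() and for keras::get_weights() called on a freshly built decoder.

```r
# Hypothetical helper (not an AutoTab function): TRUE when two weight lists
# have the same length and every pair of tensors has the same shape.
shapes_match <- function(source_weights, target_weights) {
  # Plain vectors have no dim attribute, so fall back to their length.
  shape <- function(w) if (is.null(dim(w))) length(w) else dim(w)
  if (length(source_weights) != length(target_weights)) {
    return(FALSE)
  }
  same <- vapply(
    seq_along(source_weights),
    function(i) identical(shape(source_weights[[i]]), shape(target_weights[[i]])),
    logical(1)
  )
  all(same)
}

# Mock data: stand-ins for extracted decoder weights and the target's weights.
extracted <- list(matrix(0, 3, 80), numeric(80), matrix(0, 80, 100), numeric(100))
target_ok <- list(matrix(1, 3, 80), numeric(80), matrix(1, 80, 100), numeric(100))
target_bad <- list(matrix(1, 3, 64), numeric(64), matrix(1, 64, 100), numeric(100))

ok <- shapes_match(extracted, target_ok)
bad <- shapes_match(extracted, target_bad)
```

If the check passes for a real decoder (here assumed to be an object named decoder), the weights could then be applied with keras::set_weights(decoder, weights_decoder).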

Arguments

encoder_layers

Integer. Number of encoder layers (used to compute split index).

trained_model

Keras model. Typically training$trained_model.

lip_enc

Integer (0/1). Whether spectral normalization was used in the encoder.

pi_enc

Integer. Power iterations used in encoder spectral normalization.

prior_learn

Character. "fixed" for fixed prior; any other value implies learnable MoG.

BNenc_layers

Integer. Number of encoder batch normalization (BN) layers (affects the split index).

learn_BN

Integer (0/1). Whether the batch normalization layers learn scale and center parameters.

Details

  • When prior_learn != "fixed", the final three tensors are assumed to belong to the learnable MoG prior (mog_means, mog_log_vars, mog_weights_logit) and are excluded.

  • The split index math mirrors Encoder_weights() and assumes the standard AutoTab graph wiring.

  • The full weight list is always available through keras::get_weights(trained_model). This function is a convenience within AutoTab for rebuilding the decoder; it is not the only way to obtain these weights.
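
The slicing described in the bullets above can be sketched in base R on a mock weight list. This is an illustration under stated assumptions, not AutoTab's implementation: slice_decoder() is a hypothetical helper, and the split index is passed in directly rather than computed from encoder_layers, lip_enc, pi_enc, BNenc_layers, and learn_BN as the package does.

```r
# Hypothetical helper (not an AutoTab function): keep the tensors after a
# known split index and drop the trailing MoG prior tensors when present.
slice_decoder <- function(all_weights, split_index, prior_learn) {
  n <- length(all_weights)
  # A learnable prior adds three trailing tensors (mog_means, mog_log_vars,
  # mog_weights_logit) that are not decoder weights.
  last <- if (prior_learn != "fixed") n - 3L else n
  stopifnot(split_index < last)
  all_weights[seq.int(split_index + 1L, last)]
}

# Mock list standing in for keras::get_weights(trained_model).
mock_weights <- list(
  enc_kernel = matrix(0, 4, 3), enc_bias = numeric(3),
  dec_kernel = matrix(0, 3, 4), dec_bias = numeric(4),
  mog_means = matrix(0, 2, 3), mog_log_vars = matrix(0, 2, 3),
  mog_weights_logit = numeric(2)
)

# Assumed split index of 2: the first two tensors belong to the encoder.
decoder_only <- slice_decoder(mock_weights, split_index = 2L, prior_learn = "learned")
names(decoder_only)
```

With a fixed prior there are no trailing MoG tensors, so the same helper would keep everything after the split index.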

See Also

decoder_model(), Encoder_weights(), VAE_train()

Examples

# Decoder architecture, shown for context; it is not passed to Decoder_weights()
decoder_info <- list(
  list("dense", 80, "relu"),
  list("dense", 100, "relu")
)
# \donttest{
# Runs only when TensorFlow is available and `training` already exists,
# where training <- VAE_train(...)
if (reticulate::py_module_available("tensorflow") &&
    exists("training")) {
  weights_decoder <- Decoder_weights(
    encoder_layers = 2,
    trained_model  = training$trained_model,
    lip_enc        = 0,
    pi_enc         = 0,
    prior_learn    = "fixed",
    BNenc_layers   = 0,
    learn_BN       = 0
  )
}
# }
