
hdbayes (version 0.2.0)

compute.ensemble.weights: Compute model averaging weights

Description

Compute model averaging weights for a set of Bayesian models using Bayesian model averaging (BMA), pseudo-BMA, pseudo-BMA+ (pseudo-BMA with the Bayesian bootstrap), or stacking. This function takes a list of model fit objects, each containing posterior samples from a generalized linear model (GLM) or a survival model, and returns normalized weights that can be used for model comparison or for combining posterior samples with functions such as sample.ensemble().

Usage

compute.ensemble.weights(
  fit.list,
  type = c("bma", "pseudobma", "pseudobma+", "stacking"),
  prior.prob = NULL,
  bridge.args = NULL,
  loo.args = NULL,
  loo.wts.args = NULL,
  iter_warmup = 1000,
  iter_sampling = 1000,
  chains = 4,
  ...
)

Value

The function returns a list with the following components:

weights

a numeric vector of normalized model weights corresponding to the models in fit.list. The names of the weights are made unique based on the model identifiers.

type

a character string indicating the method used to compute the model weights: one of "bma", "pseudobma", "pseudobma+", or "stacking".

res.logml

a list of log marginal likelihood estimation results; returned only when type = "bma".

loo.list

a list of outputs from loo::loo(); returned only when type is "pseudobma", "pseudobma+", or "stacking".

Arguments

fit.list

a list of model fit objects returned by functions in the hdbayes package. Each fit contains posterior samples from a generalized linear model (GLM) (e.g., via glm.pp()), an accelerated failure time (AFT) model (e.g., via aft.pp()), a piecewise exponential (PWE) model (e.g., via pwe.pp()), or a mixture cure rate model with a PWE component for the non-cured population (e.g., via curepwe.pp()). Each fit also includes two attributes: data, a list of variables specified in the data block of the Stan program, and model, a character string indicating the model name. To compute pseudo-BMA, pseudo-BMA+, or stacking weights, the fitting function must be called with get.loglik = TRUE.

type

a character string specifying the ensemble method used to compute model weights. Options are "bma" (Bayesian model averaging), "pseudobma" (pseudo-BMA without the Bayesian bootstrap), "pseudobma+" (pseudo-BMA with the Bayesian bootstrap), and "stacking".
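As a rough illustration of how the two pseudo-BMA variants differ, the base R sketch below follows the definitions in Yao et al. (2018): pseudo-BMA weights are proportional to exp(elpd_loo) for each model, and pseudo-BMA+ averages those weights over Bayesian bootstrap (Dirichlet(1, ..., 1)) reweightings of the observations. The helper names and the input matrix elpd_point (one column of pointwise LOO log predictive densities per model, simulated here) are hypothetical; this is not the hdbayes code, which passes the work to loo::loo_model_weights().

```r
## Hypothetical helpers: a sketch of pseudo-BMA and pseudo-BMA+ weights
## (Yao et al., 2018), not the hdbayes or loo implementation.

# Numerically stable softmax; keeps the names of z
softmax <- function(z) {
  z <- z - max(z)
  exp(z) / sum(exp(z))
}

# elpd_point: n x K matrix of pointwise LOO log predictive densities
pseudobma_weights_sketch <- function(elpd_point) {
  softmax(colSums(elpd_point))  # w_k proportional to exp(elpd_loo_k)
}

pseudobma_plus_weights_sketch <- function(elpd_point, B = 1000) {
  n <- nrow(elpd_point)
  draws <- matrix(NA_real_, nrow = B, ncol = ncol(elpd_point),
                  dimnames = list(NULL, colnames(elpd_point)))
  for (b in seq_len(B)) {
    alpha <- rgamma(n, shape = 1)          # Dirichlet(1, ..., 1) draw via
    alpha <- alpha / sum(alpha)            # normalized Gamma(1) variables
    z <- n * colSums(alpha * elpd_point)   # bootstrap replicate of each elpd_loo
    draws[b, ] <- softmax(z)
  }
  colMeans(draws)                          # average the weights over the draws
}

# Usage with simulated (not real) pointwise values for two models
set.seed(1)
elpd_point <- cbind(m1 = rnorm(50, -1.0, 0.3), m2 = rnorm(50, -1.1, 0.3))
w_pbma <- pseudobma_weights_sketch(elpd_point)
w_pbma_plus <- pseudobma_plus_weights_sketch(elpd_point)
```

Because the bootstrap propagates the uncertainty in each elpd_loo estimate, pseudo-BMA+ weights are typically less extreme than plain pseudo-BMA weights.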

prior.prob

a numeric vector of prior model probabilities, used only when type = "bma". Must be non-negative and sum to 1. If set to NULL, a uniform prior is used (i.e., all models are equally likely). Defaults to NULL.
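For type = "bma", the weights are posterior model probabilities: each is proportional to the prior probability of the model times its marginal likelihood. A minimal base R sketch, assuming the log marginal likelihoods are already available, works on the log scale to avoid underflow. The helper name and the input values are hypothetical and are not output from hdbayes.

```r
## Hypothetical helper: BMA posterior model probabilities,
## w_k = p(M_k) p(D | M_k) / sum_j p(M_j) p(D | M_j), computed on the log scale.
bma_weights_sketch <- function(logml, prior.prob = NULL) {
  K <- length(logml)
  if (is.null(prior.prob)) prior.prob <- rep(1 / K, K)  # uniform prior by default
  stopifnot(length(prior.prob) == K, all(prior.prob >= 0),
            isTRUE(all.equal(sum(prior.prob), 1)))
  log_num <- log(prior.prob) + logml   # log p(M_k) + log p(D | M_k)
  log_num <- log_num - max(log_num)    # shift before exponentiating (log-sum-exp)
  w <- exp(log_num) / sum(exp(log_num))
  names(w) <- names(logml)
  w
}

# Usage with made-up log marginal likelihoods for three models
logml <- c(m1 = -100, m2 = -101, m3 = -105)
w_uniform <- bma_weights_sketch(logml)
w_informative <- bma_weights_sketch(logml, prior.prob = c(0.25, 0.5, 0.25))
```

Under equal prior probabilities, a model whose log marginal likelihood is 1 unit higher receives exp(1), about 2.7, times the weight.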

bridge.args

a list of optional arguments (excluding samples, log_posterior, data, lb, and ub) to be passed to bridgesampling::bridge_sampler(). These arguments are used when estimating the log marginal likelihood, which is required if type = "bma".

loo.args

a list of optional arguments (excluding x) to be passed to loo::loo() when computing pseudo-BMA, pseudo-BMA+, or stacking weights.

loo.wts.args

a list of optional arguments (excluding x, method, and BB) to be passed to loo::loo_model_weights() when computing pseudo-BMA, pseudo-BMA+, or stacking weights.
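Stacking, unlike pseudo-BMA, chooses the weight vector on the simplex that maximizes the summed log of the weighted pointwise LOO predictive densities (Yao et al., 2018). The sketch below is a hypothetical base R version: it reparametrizes the weights with a softmax so that stats::optim() can run without constraints. It is not how loo::loo_model_weights() handles the simplex constraint internally; in that function, options such as optim_method (set through loo.wts.args) control its own optimizer.

```r
## Hypothetical helper: stacking weights (Yao et al., 2018).
## lpd_point: n x K matrix of pointwise LOO log predictive densities (K >= 2).
## Maximizes sum_i log( sum_k w_k exp(lpd_point[i, k]) ) over the simplex,
## using a softmax parametrization so stats::optim() runs unconstrained.
stacking_weights_sketch <- function(lpd_point, optim_method = "BFGS") {
  K <- ncol(lpd_point)
  stopifnot(K >= 2)
  to_simplex <- function(theta) {
    z <- c(theta, 0)          # last logit fixed at 0 for identifiability
    z <- z - max(z)
    exp(z) / sum(exp(z))
  }
  neg_log_score <- function(theta) {
    a <- sweep(lpd_point, 2, log(to_simplex(theta)), `+`)  # log(w_k) + lpd_ik
    m <- apply(a, 1, max)                                  # row-wise log-sum-exp
    -sum(m + log(rowSums(exp(a - m))))
  }
  fit <- optim(rep(0, K - 1), neg_log_score, method = optim_method)
  w <- to_simplex(fit$par)
  names(w) <- colnames(lpd_point)
  w
}

# Usage with simulated (not real) pointwise values for three models
set.seed(2)
n <- 100
lpd_point <- cbind(m1 = rnorm(n, -1.2, 0.5),
                   m2 = rnorm(n, -1.3, 0.5),
                   m3 = rnorm(n, -2.0, 0.5))
w_stack <- stacking_weights_sketch(lpd_point)
```

Because the objective scores the mixture rather than each model separately, stacking can spread weight across several models when they predict different observations well.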

iter_warmup

number of warmup iterations to run per chain. Used only when computing the log marginal likelihood (i.e., when type = "bma"). Defaults to 1000. See the iter_warmup argument of the sample() method in the cmdstanr package.

iter_sampling

number of post-warmup iterations to run per chain. Used only when computing the log marginal likelihood (i.e., when type = "bma"). Defaults to 1000. See the iter_sampling argument of the sample() method in the cmdstanr package.

chains

number of Markov chains to run. Used only when computing the log marginal likelihood (i.e., when type = "bma"). Defaults to 4. See the chains argument of the sample() method in the cmdstanr package.

...

arguments passed to the sample() method in the cmdstanr package (e.g., seed, refresh, init). These are used only when computing the log marginal likelihood (i.e., when type = "bma").

Details

The input fit.list should be a list of outputs from model fitting functions in the hdbayes package, such as glm.pp() (for generalized linear models), aft.pp() (for accelerated failure time models), pwe.pp() (for piecewise exponential (PWE) models), or curepwe.pp() (for mixture cure rate models with a PWE component for the non-cured population). To compute pseudo-BMA, pseudo-BMA+, or stacking weights, each fit must include pointwise log-likelihood values. To ensure this, the fitting function must be called with get.loglik = TRUE.
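Pointwise log-likelihood values form a matrix with one row per posterior draw and one column per observation, the layout loo::loo() expects for a matrix input. The base R sketch below (hypothetical helper names, simulated data) only illustrates why these values are needed: it turns such a matrix into pointwise LOO log predictive densities with plain importance sampling. loo::loo() instead uses Pareto-smoothed importance sampling (PSIS), which is far more stable.

```r
## Hypothetical helper: pointwise LOO log predictive densities from an
## S x n log-likelihood matrix (rows = posterior draws, columns = observations)
## using plain importance sampling: elpd_i = -log( mean_s exp(-loglik[s, i]) ).
## No Pareto smoothing is applied, unlike loo::loo().
is_loo_pointwise_sketch <- function(loglik) {
  apply(loglik, 2, function(ll) {
    m <- max(-ll)
    -(m + log(mean(exp(-ll - m))))   # stable form of -log(mean(exp(-ll)))
  })
}

# In-sample log pointwise predictive density, for comparison
lpd_pointwise_sketch <- function(loglik) {
  apply(loglik, 2, function(ll) {
    m <- max(ll)
    m + log(mean(exp(ll - m)))       # stable form of log(mean(exp(ll)))
  })
}

# Usage: normal model with known sd = 1 and a flat prior on the mean,
# so the posterior of mu is Normal(mean(y), 1 / sqrt(n))
set.seed(3)
y <- rnorm(30, mean = 1)
mu_draws <- rnorm(4000, mean(y), 1 / sqrt(length(y)))
loglik <- sapply(y, function(yi) dnorm(yi, mu_draws, 1, log = TRUE))  # 4000 x 30
elpd_point <- is_loo_pointwise_sketch(loglik)
lpd_point <- lpd_pointwise_sketch(loglik)
```

Each IS-LOO value is at most the corresponding in-sample value (a harmonic mean never exceeds an arithmetic mean), reflecting that each observation is scored as if it had been left out. With heavy-tailed importance weights this plain estimator can be very noisy, which is why PSIS is used in practice (Vehtari et al., 2017).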

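The Markov chain Monte Carlo (MCMC) arguments (iter_warmup, iter_sampling, chains, and any arguments passed through ...) are used only to compute the logarithm of the normalizing constant required for BMA, that is, when type = "bma"; they are ignored otherwise.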

References

Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). Using stacking to average Bayesian predictive distributions. Bayesian Analysis, 13(3), 917–1007.

Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5), 1413–1432.

See Also

sample.ensemble()

Examples

if (instantiate::stan_cmdstan_exists()) {
  if (requireNamespace("survival", quietly = TRUE)) {
    library(hdbayes)
    library(survival)
    data(E1684)
    data(E1690)
    ## replace 0 failure times with 0.50 days (failure times are in years)
    E1684$failtime[E1684$failtime == 0] = 0.50/365.25
    E1690$failtime[E1690$failtime == 0] = 0.50/365.25
    ## center and scale age
    E1684$cage = as.numeric(scale(E1684$age))
    E1690$cage = as.numeric(scale(E1690$age))
    data_list = list(currdata = E1690, histdata = E1684)
    ## place interval cut points at quantiles of the observed event times
    nbreaks = 3
    probs   = 1:nbreaks / nbreaks
    breaks  = as.numeric(
      quantile(E1690[E1690$failcens==1, ]$failtime, probs = probs)
    )
    breaks  = c(0, breaks)
    ## extend the last cut point so the final interval covers all follow-up
    breaks[length(breaks)] = max(10000, 1000 * breaks[length(breaks)])
    fit.pwe.pp = pwe.pp(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      breaks = breaks,
      a0 = 0.5,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    fit.pwe.post = pwe.post(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      breaks = breaks,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    fit.aft.post = aft.post(
      formula = survival::Surv(failtime, failcens) ~ treatment + sex + cage + node_bin,
      data.list = data_list,
      dist = "weibull",
      beta.sd = 10,
      get.loglik = TRUE,
      chains = 1, iter_warmup = 1000, iter_sampling = 2000
    )
    ## pseudo-BMA+ weights for the three fitted models
    compute.ensemble.weights(
      fit.list = list(fit.pwe.post, fit.pwe.pp, fit.aft.post),
      type = "pseudobma+",
      loo.args = list(save_psis = FALSE),
      loo.wts.args = list(optim_method = "BFGS")
    )
  }
}
