Learn R Programming

VBphenoR (version 1.1.0)

logit_CAVI: Variational inference for Bayesian logistic regression using CAVI algorithm

Description

Variational inference for Bayesian logistic regression using CAVI algorithm

Usage

logit_CAVI(
  X,
  y,
  prior,
  delta = 1e-16,
  maxiters = 10000,
  verbose = FALSE,
  progressbar = TRUE
)

Value

A list containing:

  • error - An error message if convergence failed; otherwise, the number of iterations taken to achieve convergence.

  • mu - A vector of posterior means.

  • Sigma - A vector of posterior variances.

  • Convergence - A vector of the ELBO at each iteration.

  • LBDifference - A vector of ELBO differences between successive iterations.

  • xi - A vector of log-odds per X row.

Arguments

X

The input design matrix. Note that X is assumed to already include the intercept column.

y

The binary response.

prior

Prior for the logistic parameters.

delta

The ELBO difference tolerance used to determine convergence.

maxiters

The maximum number of iterations to run if convergence is not achieved.

verbose

A diagnostics flag (off by default).

progressbar

Whether to display a visual progress bar indicating iterations (on by default).

Examples

Run this code
# \donttest{
  # Use the Old Faithful data to show the effect of the VB GMM priors,
  # stopping on the delta threshold.
  # NOTE(review): this example only calls vb_gmm_cavi(); it does not
  # exercise logit_CAVI(), which this help page documents.
  # ------------------------------------------------------------------------------

  # library() errors immediately if ggplot2 is missing, whereas require()
  # only returns FALSE and lets the example fail later at the first ggplot().
  library(ggplot2)

  gen_path <- tempdir()  # plots are written under the session temp directory
  data("faithful")
  X <- faithful
  P <- ncol(X)  # number of columns (variables) in the data

  # ------------------------------------------------------------------------------
  # Plotting
  # ------------------------------------------------------------------------------

  #' Plots the GMM components with centroids and saves the figure as EPS
  #'
  #' @param i Row index into `grid`; also selects the file-name suffix
  #' @param gmm_result Results from the VB GMM run (uses `z_post` and
  #'   `q_post$m`)
  #' @param var_name GMM hyperparameter name, used as the file-name prefix
  #' @param grid Hyperparameter grid; row `i` is pasted into the file name
  #' @param fig_path Directory where the plots are stored (created if needed)
  #'
  #' @returns The ggplot figure (p)
  do_prior_plots <- function(i, gmm_result, var_name, grid, fig_path) {
    # NOTE: reads the data `X` from the enclosing (global) environment.
    dd <- as.data.frame(cbind(X, cluster = gmm_result$z_post))
    dd$cluster <- as.factor(dd$cluster)

    # The group means; after transposing, the columns are the variables
    # (eruptions, waiting), as required by the aes() mapping below.
    # ---------------------------------------------------------------------------
    mu <- as.data.frame(t(gmm_result$q_post$m))

    # Plot the posterior mixture groups.
    # scale_color_discrete() does not take a palette as its first argument
    # (it is absorbed by `...` and becomes the scale name), so the custom
    # colours were never applied. scale_*_manual() maps them explicitly to
    # both the points and the ellipses; this supports up to length(cols)
    # clusters.
    # ---------------------------------------------------------------------------
    cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
    p <- ggplot() +
      geom_point(dd,
                 mapping = aes(x = eruptions, y = waiting, color = cluster)) +
      scale_color_manual(values = cols, guide = "none") +
      scale_fill_manual(values = cols) +
      geom_point(mu, mapping = aes(x = eruptions, y = waiting),
                 color = "black", pch = 7, size = 2) +
      stat_ellipse(dd, geom = "polygon",
                   mapping = aes(x = eruptions, y = waiting, fill = cluster),
                   alpha = 0.25)

    grids <- paste(grid[i, ], collapse = "_")
    ggsave(filename = paste0(var_name, "_", grids, ".eps"), plot = p,
           path = fig_path, width = 12, height = 12, units = "cm", dpi = 600,
           create.dir = TRUE, device = cairo_ps)

    p
  }

  # ------------------------------------------------------------------------------
  # Dirichlet alpha - same alpha value for each component and k=6.
  # ------------------------------------------------------------------------------
  # NOTE(review): each grid row supplies two different alpha values (x, y)
  # while k = 6; confirm how vb_gmm_cavi() expands `alpha` to k components
  # and whether the heading above reflects the intent.
  alpha_grid <- data.frame(x = c(1, 30, 70),
                           y = c(271, 237, 202))
  init <- "kmeans"
  k <- 6
  plots <- vector(mode = "list", length = nrow(alpha_grid))

  # seq_len() avoids iterating over c(1, 0) if the grid is ever empty.
  for (i in seq_len(nrow(alpha_grid))) {
    prior <- list(
      alpha = as.integer(alpha_grid[i, ])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-9,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)

    plots[[i]] <- do_prior_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
  }

  # ------------------------------------------------------------------------------
  # Dirichlet alpha - different alpha value for each component.
  # ------------------------------------------------------------------------------
  alpha_grid <- data.frame(c1 = c(1, 1, 183),
                           c2 = c(1, 92, 92),
                           c3 = c(1, 183, 198),
                           c4 = c(1, 183, 50))
  init <- "kmeans"
  k <- 4  # matches the number of alpha columns in the grid
  plots <- vector(mode = "list", length = nrow(alpha_grid))

  for (i in seq_len(nrow(alpha_grid))) {
    # Row 1 is a flat prior (all ones); rows 2-3 weight the components
    # unevenly.
    prior <- list(
      alpha = as.integer(alpha_grid[i, ])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)

    plots[[i]] <- do_prior_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
  }

  # ------------------------------------------------------------------------------
  # Normal-Wishart lambda for precision proportionality
  # ------------------------------------------------------------------------------
  lambda_grid <- data.frame(c1 = c(0.1, 0.9),
                            c2 = c(0.1, 0.9),
                            c3 = c(0.1, 0.9),
                            c4 = c(0.1, 0.9))
  init <- "kmeans"
  k <- 4
  plots <- vector(mode = "list", length = nrow(lambda_grid))

  for (i in seq_len(nrow(lambda_grid))) {
    # The lambda values are passed to vb_gmm_cavi() under the prior name
    # `beta`, one value per component.
    prior <- list(
      beta = as.numeric(lambda_grid[i, ])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    plots[[i]] <- do_prior_plots(i, gmm_result, "lambda", lambda_grid,
                                 gen_path)
  }

  # ------------------------------------------------------------------------------
  # Normal-Wishart W0 (assuming simplest-case diagonal covariance matrix) & logW
  # ------------------------------------------------------------------------------

  w_grid <- data.frame(c1 = c(0.001, 2.001),
                       c2 = c(0.001, 2.001),
                       c3 = c(0.001, 2.001),
                       c4 = c(0.001, 2.001))
  init <- "kmeans"
  k <- 4
  plots <- vector(mode = "list", length = nrow(w_grid))

  for (i in seq_len(nrow(w_grid))) {
    # W0 is P x P, so only P diagonal values are used. diag() previously
    # received the whole 1-row data frame and silently kept the first P
    # values (recycling if fewer); rep_len() makes that explicit. All grid
    # columns still feed the plot file name.
    w0 <- diag(rep_len(as.numeric(w_grid[i, ]), P), nrow = P)
    prior <- list(
      W = w0,
      # -2 * sum(log(diag(chol(W0)))) equals -log|W0|, i.e. log|W0^-1|.
      # NOTE(review): confirm this is the sign vb_gmm_cavi() expects for logW.
      logW = -2 * sum(log(diag(chol(w0))))
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    plots[[i]] <- do_prior_plots(i, gmm_result, "w", w_grid, gen_path)
  }
# }


Run the code above in your browser using DataLab