Learn R Programming

VBphenoR (version 1.1.0)

vb_gmm_cavi: Main algorithm function for the VB CAVI GMM

Description

Main algorithm function for the VB CAVI GMM

Usage

vb_gmm_cavi(
  X,
  k,
  prior = NULL,
  delta = 1e-06,
  maxiters = 5000,
  init = "kmeans",
  initParams = NULL,
  stopIfELBOReverse = FALSE,
  verbose = FALSE,
  logDiagnostics = FALSE,
  logFilename = "vb_gmm_log.txt",
  progressbar = TRUE
)

Value

A list containing:

  • error - An error message if convergence failed; otherwise, the number of iterations taken to achieve convergence.

  • z_post - A vector of posterior cluster mappings.

  • q_post - A list of the fitted posterior Q family. q_post includes:

    • alpha - The Dirichlet prior for the mixing coefficient

    • lambda - The proportionality of the precision for the Normal-Wishart prior (called 'beta' in Bishop)

    • m - The mean vector for the Normal part of the Normal-Wishart prior

    • v - The degrees of freedom of the Wishart gamma ensuring it is well defined

    • U - The W hyperparameter of the Wishart conjugate prior for the precision of the mixtures. k sets of D x D symmetric, positive definite matrices.

    • logW - The logW term used to calculate the expectation of the mixture component precision. A vector of k.

    • R - The responsibilities. An n x k matrix.

    • logR - The log of responsibilities. An n x k matrix.

  • elbo - A vector of the ELBO at each iteration.

  • elbo_d - A vector of ELBO differences between each iteration.

Arguments

X

n x p data matrix (or data frame that will be converted to a matrix).

k

guess for the number of mixture components.

prior

Prior for the GMM parameters.

delta

change in ELBO that triggers algorithm stopping.

maxiters

maximum iterations to run if delta does not stop the algorithm already.

init

method used to initialize the clusters; one of "random", "kmeans" or "dbscan".

initParams

initialization parameters for dbscan. NULL if dbscan not selected for init.

stopIfELBOReverse

stop the run if the ELBO at iteration t is detected to have reversed from iteration t-1.

verbose

print out information per iteration to track progress in case of long-running experiments.

logDiagnostics

log detailed diagnostics. If TRUE, a diagnostics RDS file will be created using the path specified in logFilename.

logFilename

the filename of the diagnostics log.

progressbar

A visual progress bar to indicate iterations (on by default).

Examples

Run this code
# \donttest{
  # Use Old Faithful data to show the effect of VB GMM Priors,
  # stopping on delta threshold
  # ------------------------------------------------------------------------------

  # Attach ggplot2 for the plotting helper below. library() stops with an
  # error if the package is missing, whereas require() would only return
  # FALSE and let the script fail later at the first ggplot() call.
  library(ggplot2)

  # Directory that receives the generated figures.
  gen_path <- tempdir()

  # Old Faithful data: the observations to cluster and their number of columns.
  data("faithful")
  X <- faithful
  P <- ncol(X)

  # ------------------------------------------------------------------------------
  # Plotting
  # ------------------------------------------------------------------------------

  #' Plots the GMM components with centroids
  #'
  #' Draws the observations in the global `X` coloured by their posterior
  #' cluster assignment, overlays the component means and one ellipse per
  #' cluster, and saves the figure as `<var_name>_<grid row values>.eps` in
  #' `fig_path`.
  #'
  #' @param i Row index into `grid` identifying the prior setting of this run
  #' @param gmm_result Results from the VB GMM run (`z_post` and `q_post$m`
  #'   are read)
  #' @param var_name GMM hyperparameter name, used as the file name prefix
  #' @param grid Data frame of prior settings; the values of row `i` are
  #'   joined with "_" to form the file name suffix
  #' @param fig_path Path to the directory where the plots should be stored
  #'
  #' @returns The ggplot figure (p)
  do_prior_plots <- function(i, gmm_result, var_name, grid, fig_path) {
    # X is the data set defined at the top of this example script.
    dd <- as.data.frame(cbind(X, cluster = gmm_result$z_post))
    dd$cluster <- as.factor(dd$cluster)

    # The group means
    # ---------------------------------------------------------------------------
    # NOTE(review): assumes the transposed `m` carries the column names
    # `eruptions` and `waiting` used in the aes() mapping below - confirm.
    mu <- as.data.frame(t(gmm_result$q_post$m))

    # Plot the posterior mixture groups
    # ---------------------------------------------------------------------------
    cols <- c("#1170AA", "#55AD89", "#EF6F6A", "#D3A333", "#5FEFE8", "#11F444")
    p <- ggplot() +
      geom_point(dd, mapping = aes(x = eruptions, y = waiting, color = cluster)) +
      # A manual scale is needed for the custom palette to take effect:
      # scale_color_discrete() treats its first positional argument as the
      # scale name, so the colours were silently ignored.
      scale_color_manual(values = cols, guide = "none") +
      # Same palette for the ellipses so they keep matching the points.
      scale_fill_manual(values = cols) +
      geom_point(mu, mapping = aes(x = eruptions, y = waiting), color = "black",
                 shape = 7, size = 2) +
      stat_ellipse(dd, geom = "polygon",
                   mapping = aes(x = eruptions, y = waiting, fill = cluster),
                   alpha = 0.25)

    # File name suffix built from the prior values of the current grid row.
    grids <- paste((grid[i, ]), collapse = "_")
    ggsave(filename = paste0(var_name, "_", grids, ".eps"), plot = p,
           path = fig_path, width = 12, height = 12, units = "cm", dpi = 600,
           create.dir = TRUE, device = cairo_ps)

    p
  }

  # ------------------------------------------------------------------------------
  # Dirichlet alpha - same alpha value for each component and k=6.
  # ------------------------------------------------------------------------------
  # Each row of alpha_grid is one Dirichlet alpha prior setting to try.
  alpha_grid <- data.frame(x = c(1, 30, 70),
                           y = c(271, 237, 202))
  init <- "kmeans"
  k <- 6
  plots <- vector(mode = "list", length = nrow(alpha_grid))

  # seq_len() is safe for an empty grid, where 1:nrow() would yield c(1, 0).
  for (i in seq_len(nrow(alpha_grid))) {
    prior <- list(
      alpha = as.integer(alpha_grid[i, ])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-9,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)

    plots[[i]] <- do_prior_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
  }

  # ------------------------------------------------------------------------------
  # Dirichlet alpha - different alpha value for each component.
  # ------------------------------------------------------------------------------
  # One column per mixture component; each row is one alpha prior setting.
  alpha_grid <- data.frame(c1 = c(1, 1, 183),
                           c2 = c(1, 92, 92),
                           c3 = c(1, 183, 198),
                           c4 = c(1, 183, 50))
  init <- "kmeans"
  k <- 4
  plots <- vector(mode = "list", length = nrow(alpha_grid))

  # seq_len() is safe for an empty grid, where 1:nrow() would yield c(1, 0).
  for (i in seq_len(nrow(alpha_grid))) {
    prior <- list(
      alpha = as.integer(alpha_grid[i, ]) # set most of the weight on one component
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)

    plots[[i]] <- do_prior_plots(i, gmm_result, "alpha", alpha_grid, gen_path)
  }

  # ------------------------------------------------------------------------------
  # Normal-Wishart lambda for precision proportionality
  # ------------------------------------------------------------------------------
  # One column per mixture component; each row is one lambda prior setting.
  lambda_grid <- data.frame(c1 = c(0.1, 0.9),
                            c2 = c(0.1, 0.9),
                            c3 = c(0.1, 0.9),
                            c4 = c(0.1, 0.9))
  init <- "kmeans"
  k <- 4
  plots <- vector(mode = "list", length = nrow(lambda_grid))

  # seq_len() is safe for an empty grid, where 1:nrow() would yield c(1, 0).
  for (i in seq_len(nrow(lambda_grid))) {
    prior <- list(
      beta = as.numeric(lambda_grid[i, ])
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    plots[[i]] <- do_prior_plots(i, gmm_result, "lambda", lambda_grid, gen_path)
  }

  # ------------------------------------------------------------------------------
  # Normal-Wishart W0 (assuming simplest-case diagonal covariance matrix) & logW
  # ------------------------------------------------------------------------------

  # Each row is one setting for the diagonal entries of W0.
  w_grid <- data.frame(c1 = c(0.001, 2.001),
                       c2 = c(0.001, 2.001),
                       c3 = c(0.001, 2.001),
                       c4 = c(0.001, 2.001))
  init <- "kmeans"
  k <- 4
  plots <- vector(mode = "list", length = nrow(w_grid))

  # seq_len() is safe for an empty grid, where 1:nrow() would yield c(1, 0).
  for (i in seq_len(nrow(w_grid))) {
    # Convert the data-frame row to a plain numeric vector before building the
    # P x P diagonal matrix, rather than relying on diag() coercing a list.
    w0 <- diag(as.numeric(w_grid[i, ]), P)
    prior <- list(
      W = w0,
      # -2 * sum(log(diag(chol(w0)))) equals -log(det(w0)).
      logW = -2 * sum(log(diag(chol(w0))))
    )

    gmm_result <- vb_gmm_cavi(X = X, k = k, prior = prior, delta = 1e-8,
                              init = init, verbose = FALSE,
                              logDiagnostics = FALSE)
    plots[[i]] <- do_prior_plots(i, gmm_result, "w", w_grid, gen_path)
  }
# }


Run the code above in your browser using DataLab