
batchmix (version 2.2.1)

predictFromMultipleChains: Predict from multiple MCMC chains

Description

Applies a burn-in to multiple chains produced by ``runMCMCChains`` and combines them to find a point estimate.
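The idea of burning in several chains and combining them can be pictured with a minimal base-R sketch. For illustration only, it assumes each chain is an $S \times N \times K$ array of sampled allocation probabilities (samples by items by classes). The function `combineChainsSketch` and this storage layout are hypothetical and are not batchmix's internal representation. Each chain loses its first `burn` samples, the remaining samples from all chains are pooled, and the median over the pooled samples gives the point estimate.

```r
# Hypothetical sketch, NOT batchmix's implementation.
# Each chain is assumed to be an S x N x K array of sampled allocation
# probabilities (samples x items x classes).
combineChainsSketch <- function(chains, burn) {
  dims <- dim(chains[[1]])
  pooled <- do.call(rbind, lapply(chains, function(arr) {
    # Drop the first `burn` samples of this chain (the burn-in)
    kept <- arr[seq_len(dim(arr)[1]) > burn, , , drop = FALSE]
    # Flatten to a (retained samples) x (N * K) matrix
    matrix(kept, nrow = dim(kept)[1], ncol = prod(dim(kept)[2:3]))
  }))
  # Median over all pooled samples, reshaped back to an N x K matrix
  matrix(apply(pooled, 2, median), nrow = dims[2], ncol = dims[3])
}

set.seed(1)
chain1 <- array(runif(10 * 3 * 2), dim = c(10, 3, 2))
chain2 <- array(runif(10 * 3 * 2), dim = c(10, 3, 2))
est <- combineChainsSketch(list(chain1, chain2), burn = 4)
```

With two chains of 10 samples each and `burn = 4`, each entry of `est` is the median of the 12 retained draws for that item and class.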

Usage

predictFromMultipleChains(
  mcmc_outputs,
  burn,
  point_estimate_method = "median",
  chains_already_processed = FALSE
)

Value

A named list of quantities related to prediction/clustering:

* ``allocation_probability``: If the model is semi-supervised, a list holding an $N \times K$ matrix: the point estimate of the allocation probability of each data point to each class.

* ``prob``: An $N$-vector giving, for each data point, the point estimate of the probability of being allocated to its most probable class.

* ``pred``: An $N$-vector of the predicted class for each sample. If the model is unsupervised, the ``salso`` function (Dahl et al., 2021) is applied to the sampled partitions with its default settings.

* ``samples``: List of sampled allocations for each view. Columns correspond to items being clustered, rows to MCMC samples.
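To illustrate how ``prob`` and ``pred`` can relate to ``allocation_probability`` in the semi-supervised case, the sketch below assumes the predicted class is the column with the largest point-estimated probability in each row. That rule is an assumption made for illustration, the toy matrix is made up, and the unsupervised case (where ``salso`` is used) is not covered.

```r
# Illustrative only: a made-up 3 x 2 matrix standing in for the point
# estimate of the allocation probabilities (rows = items, columns = classes).
alloc_prob <- matrix(
  c(0.9, 0.1,
    0.3, 0.7,
    0.55, 0.45),
  nrow = 3, byrow = TRUE
)

# Assumed rule: the predicted class is the most probable one per item...
pred <- apply(alloc_prob, 1, which.max)
# ...and `prob` is the probability attached to that class
prob <- apply(alloc_prob, 1, max)

pred  # 1 2 1
prob  # 0.90 0.70 0.55
```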

Arguments

mcmc_outputs

Output from ``runMCMCChains``.

burn

The number of MCMC samples to drop as part of a burn-in.

point_estimate_method

Summary statistic used to define the point estimate. Must be ``'mean'`` or ``'median'``. ``'median'`` is the default.
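The choice between ``'mean'`` and ``'median'`` matters when the sampled values are skewed, for example when a few draws sit far from the rest. The hypothetical helper `pointEstimateSketch` below mirrors the documented restriction to these two options; it is a sketch for illustration, not the function batchmix calls internally, and the draws are made up.

```r
# Hypothetical helper, not part of batchmix: summarise sampled values
# of a quantity with either the mean or the median.
pointEstimateSketch <- function(draws, method = "median") {
  if (!method %in% c("mean", "median")) {
    stop("method must be 'mean' or 'median'")
  }
  if (method == "median") median(draws) else mean(draws)
}

# Made-up draws of one allocation probability with two outlying values
draws <- c(0.80, 0.82, 0.79, 0.81, 0.80, 0.10, 0.12)

pointEstimateSketch(draws, "median")  # 0.8
pointEstimateSketch(draws, "mean")    # about 0.606
```

Here the two low draws pull the mean well below the bulk of the samples, while the median stays at 0.8.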

chains_already_processed

Logical indicating if the chains have already had a burn-in applied.
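The flag lets a caller skip the burn-in step when it has already been applied, so that samples are not discarded twice. The sketch below shows that control flow with a hypothetical function, `applyBurnInSketch`, and chains stored as plain matrices (rows = MCMC samples, columns = items); it illustrates the idea only and is not batchmix code.

```r
# Hypothetical illustration of the chains_already_processed logic.
# Each chain is a matrix of sampled allocations: rows are MCMC samples,
# columns are items.
applyBurnInSketch <- function(chains, burn, chains_already_processed = FALSE) {
  if (chains_already_processed) {
    # Burn-in was applied earlier: return the chains unchanged
    return(chains)
  }
  # Otherwise drop the first `burn` rows of each chain
  lapply(chains, function(m) m[seq_len(nrow(m)) > burn, , drop = FALSE])
}

chains <- list(
  matrix(1:20, nrow = 5),  # 5 samples x 4 items
  matrix(21:40, nrow = 5)
)

once <- applyBurnInSketch(chains, burn = 2)
nrow(once[[1]])  # 3

# Passing the already burnt-in chains with the flag set keeps all 3 rows
twice <- applyBurnInSketch(once, burn = 2, chains_already_processed = TRUE)
nrow(twice[[1]])  # 3
```

Without the flag, a second call would drop two more rows and leave a single sample per chain.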

Examples

# \donttest{
library(batchmix)

# Data dimensions
N <- 600
P <- 4
K <- 5
B <- 7

# Generating model parameters
mean_dist <- 2.25
batch_dist <- 0.3
group_means <- seq(1, K) * mean_dist
batch_shift <- rnorm(B, mean = batch_dist, sd = batch_dist)
std_dev <- rep(2, K)
batch_var <- rep(1.2, B)
group_weights <- rep(1 / K, K)
batch_weights <- rep(1 / B, B)
dfs <- c(4, 7, 15, 60, 120)

my_data <- generateBatchData(
  N,
  P,
  group_means,
  std_dev,
  batch_shift,
  batch_var,
  group_weights,
  batch_weights,
  type = "MVT",
  group_dfs = dfs
)

X <- my_data$observed_data

true_labels <- my_data$group_IDs
fixed <- my_data$fixed
batch_vec <- my_data$batch_IDs

alpha <- 1
initial_labels <- generateInitialLabels(alpha, K, fixed, true_labels)

# Sampling parameters
R <- 1000
thin <- 25
burn <- 100
n_chains <- 2

# Density choice
type <- "MVT"

# Run n_chains MCMC chains from the initial labels
mcmc_outputs <- runMCMCChains(
  X,
  n_chains,
  R,
  thin,
  batch_vec,
  type,
  initial_labels = initial_labels,
  fixed = fixed
)
# Apply the burn-in and combine the chains into point estimates
ensemble_mod <- predictFromMultipleChains(mcmc_outputs, burn)
# }
