
shapr (version 0.2.2)

sample_ctree: Sample ctree variables from a given conditional inference tree

Description

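Samples values of the dependent (unconditioned) features by drawing training observations from the terminal node of a given conditional inference tree that the test observation falls into.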

Usage

sample_ctree(tree, n_samples, x_test, x_train, p, sample)

Value

data.table with up to n_samples (conditional) samples, drawn from the training observations in the terminal node of the tree that x_test falls into.
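To illustrate the shape of the returned table, here is a minimal base-R sketch. It assumes the training rows have already been chosen (see the sample argument), that the dependent features are copied from those rows, and that the conditioned-upon features are held fixed at their values in x_test. The helper name assemble_ctree_samples, the reordering of columns to match x_train, and the toy data are hypothetical and not taken from shapr; the package itself returns a data.table.

```r
# Hypothetical helper (not part of shapr): build a conditional sample table
# from training rows that were already selected from the test observation's terminal node.
assemble_ctree_samples <- function(x_train, x_test, rows, given_ind, dependent_ind) {
  # Dependent (unconditioned) features come from the selected training rows
  dep <- x_train[rows, dependent_ind, drop = FALSE]
  # Conditioned-upon features are repeated at the test observation's values
  given <- x_test[rep(1, length(rows)), given_ind, drop = FALSE]
  out <- cbind(dep, given)
  rownames(out) <- NULL
  # Assumption of this sketch: restore the column order of the training data
  out[, colnames(x_train)]
}

# Toy data: three features, condition on feature 2
x_train <- data.frame(V1 = 1:6, V2 = 11:16, V3 = 21:26)
x_test <- data.frame(V1 = 100, V2 = 200, V3 = 300)
res <- assemble_ctree_samples(x_train, x_test, rows = c(2, 5, 5),
                              given_ind = 2, dependent_ind = c(1, 3))
print(res)
```

Each row pairs one training draw of V1 and V3 with the fixed test value of V2, so repeated rows simply reflect that the same training observation was drawn more than once.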

Arguments

tree

List. Contains tree, a conditional inference tree fitted with party::ctree, and given_ind, the indices of the features to condition upon. The example below also supplies dependent_ind, the indices of the remaining (dependent) features.

n_samples

Numeric. The number of samples to use in the Monte Carlo integration.

x_test

Matrix, data.frame or data.table with the features of the observation whose prediction is to be explained (test data). Dimension 1 x p or p x 1.

x_train

Matrix, data.frame or data.table with training data.

p

Positive integer. The number of features.

sample

Logical. If TRUE, the method samples n_samples observations from the terminal node of the tree that x_test falls into. If FALSE, the method takes all the observations in that terminal node when there are fewer than n_samples of them.
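The row-selection rule described for sample can be sketched in base R as follows. The helper select_ctree_rows is hypothetical: it takes the terminal-node IDs of the training rows (as party::where would report them) and the node of the test observation as plain vectors. Drawing with replacement, and falling back to sampling when sample = FALSE but the node holds at least n_samples rows, are assumptions of this sketch rather than documented behaviour.

```r
# Hypothetical helper (not part of shapr): choose training rows for one test observation.
select_ctree_rows <- function(fit_nodes, pred_node, n_samples, sample) {
  # Training rows that fall in the same terminal node as the test observation
  in_node <- which(fit_nodes == pred_node)
  if (!sample && length(in_node) < n_samples) {
    # Few enough observations: use all of them, without resampling
    return(in_node)
  }
  # Otherwise draw n_samples rows from the node (with replacement: an assumption here)
  in_node[sample.int(length(in_node), n_samples, replace = TRUE)]
}

set.seed(1)
# Hypothetical terminal-node IDs for 7 training rows; the test observation lands in node 5
fit_nodes <- c(3, 3, 5, 3, 5, 5, 3)
rows_all <- select_ctree_rows(fit_nodes, pred_node = 5, n_samples = 10, sample = FALSE)
rows_drawn <- select_ctree_rows(fit_nodes, pred_node = 5, n_samples = 10, sample = TRUE)
print(rows_all)
print(rows_drawn)
```

Indexing in_node through sample.int avoids the base::sample pitfall where a single-element numeric vector is treated as the range 1:x.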

Author

Annabelle Redelmeier

Examples

if (requireNamespace("MASS", quietly = TRUE) && requireNamespace("party", quietly = TRUE)) {
  set.seed(123)
  m <- 10          # number of features
  n <- 40          # number of training observations
  n_samples <- 50
  mu <- rep(1, m)
  cov_mat <- cov(matrix(rnorm(n * m), n, m))
  x_train <- data.table::data.table(MASS::mvrnorm(n, mu, cov_mat))
  x_test <- MASS::mvrnorm(1, mu, cov_mat)
  x_test_dt <- data.table::setDT(as.list(x_test))
  # Condition on features 4 and 7; the remaining features are the dependent ones
  given_ind <- c(4, 7)
  dependent_ind <- seq_len(ncol(x_train))[-given_ind]
  x <- x_train[, given_ind, with = FALSE]
  y <- x_train[, dependent_ind, with = FALSE]
  df <- data.table::data.table(cbind(y, x))
  # Responses are named Y1, Y2, ...; the conditioning features keep their V names
  colnames(df) <- c(paste0("Y", seq_len(ncol(y))), paste0("V", given_ind))
  ynam <- paste0("Y", seq_len(ncol(y)))
  fmla <- as.formula(paste(paste(ynam, collapse = "+"), "~ ."))
  # Fit a multivariate conditional inference tree of the dependent features on the given ones
  datact <- party::ctree(fmla, data = df, controls = party::ctree_control(
    minbucket = 7,
    mincriterion = 0.95
  ))
  tree <- list(tree = datact, given_ind = given_ind, dependent_ind = dependent_ind)
  shapr:::sample_ctree(
    tree = tree, n_samples = n_samples, x_test = x_test_dt, x_train = x_train,
    p = length(x_test), sample = TRUE
  )
}
