# Example: Continuous outcome with ShrinkageTrees, two priors
#
# Simulate a small regression problem: `p` uniform covariates for a training
# and a test set, where only the first covariate enters the mean of `y`
# (plus standard-normal noise).
#
# Fix the RNG state before simulating so the data -- and therefore every
# result below -- are reproducible. The `seed` argument of the later
# ShrinkageTrees() calls is only passed to the model fits; it cannot affect
# data that were already generated here.
set.seed(1)
n <- 50
p <- 3
X <- matrix(runif(n * p), nrow = n, ncol = p)
X_test <- matrix(runif(n * p), nrow = n, ncol = p)
y <- X[, 1] + rnorm(n)
# Fit ShrinkageTrees with the standard horseshoe prior.
# Both the local and the global hyperparameter are 0.1 / sqrt(5); the 5
# matches `number_of_trees`, so presumably the prior scale is shrunk with the
# ensemble size -- confirm against the ShrinkageTrees documentation.
fit_horseshoe <- ShrinkageTrees(
  # data
  y = y,
  X_train = X,
  X_test = X_test,
  outcome_type = "continuous",
  # model and prior
  number_of_trees = 5,
  prior_type = "horseshoe",
  local_hp = 0.1 / sqrt(5),
  global_hp = 0.1 / sqrt(5),
  # sampler settings (kept tiny so the example runs quickly)
  N_burn = 5,
  N_post = 10,
  store_posterior_sample = TRUE,
  seed = 1,
  verbose = FALSE
)
# Fit ShrinkageTrees with a half-Cauchy prior.
# Unlike the horseshoe fit, only a local hyperparameter (1 / sqrt(5)) is
# supplied here; everything else matches the horseshoe call so that the two
# fits differ only in their prior specification.
fit_halfcauchy <- ShrinkageTrees(
  # data
  y = y,
  X_train = X,
  X_test = X_test,
  outcome_type = "continuous",
  # model and prior
  number_of_trees = 5,
  prior_type = "half-cauchy",
  local_hp = 1 / sqrt(5),
  # sampler settings (kept tiny so the example runs quickly)
  N_burn = 5,
  N_post = 10,
  store_posterior_sample = TRUE,
  seed = 1,
  verbose = FALSE
)
# Posterior summaries of the in-sample (training) predictions for each fit.
# NOTE(review): assumes `train_predictions_sample` is a matrix with one row
# per retained posterior draw and one column per training observation --
# confirm against the ShrinkageTrees documentation. Under that layout:
#   colMeans() -> posterior mean prediction for each observation;
#   rowMeans() -> posterior draws of the average prediction over the sample;
#   mean() of those draws -> one overall average prediction.

# Horseshoe prior
pred_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample)
post_mean_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample)
mean_pred_horseshoe <- mean(post_mean_horseshoe)

# Half-Cauchy prior
pred_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample)
post_mean_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample)
mean_pred_halfcauchy <- mean(post_mean_halfcauchy)
# Run the code above in your browser using DataLab