
ShrinkageTrees (version 1.0.0)

HorseTrees: Horseshoe Regression Trees (HorseTrees)

Description

Fits a Bayesian Horseshoe Trees model with a single learner. The step heights are regularized with a global-local Horseshoe prior whose scale is controlled by the parameter k. Supports continuous, binary, and right-censored (survival) outcomes.
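To build intuition for this prior, the sketch below draws step heights from a global-local Horseshoe prior in base R. It assumes the textbook form: each step height is normal with standard deviation equal to the product of a global scale and a local scale, and both scales are half-Cauchy. Setting both half-Cauchy scale parameters to k / sqrt(number_of_trees) is one reading of the k argument below; the package's internal parameterization may differ. The function `draw_horseshoe_heights` is hypothetical and not part of ShrinkageTrees.

```r
# Hypothetical helper (not a ShrinkageTrees function): draw step heights
# from a global-local Horseshoe prior. Global and local scales are
# half-Cauchy; their scale parameters are passed in explicitly.
draw_horseshoe_heights <- function(n_heights, global_scale, local_scale) {
  # Half-Cauchy draw = absolute value of a zero-centered Cauchy draw.
  tau <- abs(rcauchy(1, location = 0, scale = global_scale))
  lambda <- abs(rcauchy(n_heights, location = 0, scale = local_scale))
  # Given the scales, each step height is normal with sd = lambda * tau.
  rnorm(n_heights, mean = 0, sd = lambda * tau)
}

# Assumption: both scales set to alpha = k / sqrt(number_of_trees).
k <- 0.1
number_of_trees <- 200
alpha <- k / sqrt(number_of_trees)

set.seed(42)
heights <- draw_horseshoe_heights(10000, global_scale = alpha, local_scale = alpha)

# Upper tail relative to the median: most heights are shrunk strongly
# toward zero, while a few escape shrinkage.
quantile(abs(heights), probs = c(0.5, 0.99)) / median(abs(heights))
```

The heavy-tailed local scales are what let a few large effects survive while the global scale pulls the bulk of the step heights toward zero.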

Usage

HorseTrees(
  y,
  status = NULL,
  X_train,
  X_test = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees = 200,
  k = 0.1,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 1000,
  N_burn = 1000,
  delayed_proposal = 5,
  store_posterior_sample = TRUE,
  seed = NULL,
  verbose = TRUE
)

Value

A named list with the following elements:

train_predictions

Vector of posterior mean predictions on the training data.

test_predictions

Vector of posterior mean predictions on the test data (or at the mean covariate vector if X_test is not provided).

sigma

Vector of posterior samples of the error variance.

acceptance_ratio

Average acceptance ratio across trees during sampling.

train_predictions_sample

Matrix of posterior samples of training predictions (iterations in rows, observations in columns). Present only if store_posterior_sample = TRUE.

test_predictions_sample

Matrix of posterior samples of test predictions. Present only if store_posterior_sample = TRUE.

train_probabilities

Vector of posterior mean probabilities on the training data (only for outcome_type = "binary").

test_probabilities

Vector of posterior mean probabilities on the test data (only for outcome_type = "binary").

train_probabilities_sample

Matrix of posterior samples of training probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).

test_probabilities_sample

Matrix of posterior samples of test probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).
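The sample matrices above store one row per posterior iteration and one column per observation, so per-observation summaries are column-wise operations. The sketch below uses a simulated matrix as a stand-in for `fit$train_predictions_sample` (the object name `fit` is illustrative only); with a real fit, the column means should presumably agree with `train_predictions`.

```r
# Stand-in for fit$train_predictions_sample: a simulated matrix with
# 1000 posterior iterations (rows) for 25 observations (columns).
set.seed(1)
pred_sample <- matrix(rnorm(1000 * 25), nrow = 1000, ncol = 25)

# Posterior mean for each observation.
post_mean <- colMeans(pred_sample)

# Pointwise 95% credible intervals: one column per observation,
# rows hold the 2.5% and 97.5% quantiles.
cred_int <- apply(pred_sample, 2, quantile, probs = c(0.025, 0.975))
```

The same pattern applies to `test_predictions_sample` and, for binary outcomes, to the probability sample matrices.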

Arguments

y

Outcome vector. Numeric. Can represent continuous outcomes, binary outcomes (0/1), or follow-up times for survival data.

status

Optional censoring indicator vector (1 = event occurred, 0 = censored). Required if outcome_type = "right-censored".

X_train

Covariate matrix for training. Each row corresponds to an observation, and each column to a covariate.

X_test

Optional covariate matrix for test data. If NULL, defaults to the vector of training covariate means.

outcome_type

Type of outcome. One of "continuous", "binary", or "right-censored".

timescale

Indicates the scale of the follow-up times. Options are "time" (positive follow-up times, which are log-transformed internally) or "log" (times that are already log-transformed). Only used when outcome_type = "right-censored".

number_of_trees

Number of trees in the ensemble. Default is 200.

k

Horseshoe scale hyperparameter (default 0.1). This parameter controls the overall level of shrinkage by setting the scale for both the global and local shrinkage components, which are parameterized as \(\alpha = \frac{k}{\sqrt{\mathrm{number\_of\_trees}}}\). Because \(\alpha\) decreases as the number of trees grows, each tree's step heights are shrunk more strongly in larger ensembles.

power

Power parameter for tree structure prior. Default is 2.0.

base

Base parameter for tree structure prior. Default is 0.95.

p_grow

Probability of proposing a grow move. Default is 0.4.

p_prune

Probability of proposing a prune move. Default is 0.4.

nu

Degrees of freedom for the error distribution prior. Default is 3.

q

Quantile hyperparameter for the error variance prior. Default is 0.90.

sigma

Optional known value of the error standard deviation. If NULL, it is estimated from the data.

N_post

Number of posterior samples to store. Default is 1000.

N_burn

Number of burn-in iterations. Default is 1000.

delayed_proposal

Number of delayed iterations before proposal. Only for reversible updates. Default is 5.

store_posterior_sample

Logical; whether to store posterior samples for each iteration. Default is TRUE.

seed

Random seed for reproducibility.

verbose

Logical; whether to print verbose output. Default is TRUE.
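The power, base, nu, and q arguments share their names with the hyperparameters of the standard BART priors. Assuming HorseTrees uses those standard forms (this page does not state the formulas), a node at depth d is split with prior probability base * (1 + d)^(-power), and the error variance has a scaled inverse chi-squared prior with nu degrees of freedom whose scale lambda is chosen so that P(sigma < sigma_hat) = q for a rough estimate sigma_hat. The helpers `split_prob` and `sigma_prior_scale` are hypothetical illustrations of those formulas, not package functions.

```r
# Hypothetical helpers illustrating the standard BART prior forms;
# ShrinkageTrees may implement these differently.

# Prior probability that a node at depth `depth` is split.
split_prob <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

# Scale lambda of the prior sigma^2 ~ nu * lambda / chisq(nu), chosen so
# that P(sigma < sigma_hat) = q.
sigma_prior_scale <- function(sigma_hat, nu = 3, q = 0.9) {
  sigma_hat^2 * qchisq(1 - q, df = nu) / nu
}

# Deeper nodes are much less likely to split under the defaults.
round(split_prob(0:3), 4)
# [1] 0.9500 0.2375 0.1056 0.0594

# Check the calibration: P(sigma^2 < sigma_hat^2) should equal q.
lambda <- sigma_prior_scale(sigma_hat = 1)
pchisq(3 * lambda / 1^2, df = 3, lower.tail = FALSE)
# [1] 0.9
```

Under this reading, raising power or lowering base favors shallower trees, while raising q places more prior mass on error standard deviations below the rough estimate.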

Details

For continuous outcomes, the model centers and optionally standardizes the outcome using a prior guess of the standard deviation. For binary outcomes, the function uses a probit link formulation. For right-censored outcomes (survival data), the function can handle follow-up times either on the original time scale or log-transformed. A generalized implementation that supports multiple priors is provided by ShrinkageTrees.
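The outcome handling described above can be sketched in base R. The specific choices are assumptions for illustration: the continuous outcome is centered and scaled by its sample standard deviation (HorseTrees uses a prior guess of the standard deviation, which may be computed differently), binary probabilities are obtained from a latent predictor through the standard normal CDF, as in any probit model, and follow-up times are log-transformed, as with timescale = "time".

```r
set.seed(7)

# Continuous outcome: center and scale, then map back to the original scale.
y <- rnorm(50, mean = 10, sd = 3)
y_center <- mean(y)
y_scale <- sd(y)                        # assumption: sample sd as the "prior guess"
y_std <- (y - y_center) / y_scale
y_back <- y_std * y_scale + y_center    # back-transform

# Binary outcome with a probit link: P(y = 1) = pnorm(latent predictor).
latent <- c(-1.5, 0, 1.5)
prob <- pnorm(latent)

# Right-censored outcome: positive follow-up times on the log scale.
time <- rexp(50, rate = 0.1)
log_time <- log(time)
```

Working on a standardized or log scale keeps the default hyperparameters (such as k) meaningful across outcomes measured in different units.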

See Also

ShrinkageTrees, CausalHorseForest, CausalShrinkageForest

Examples

library(ShrinkageTrees)

# Minimal example: continuous outcome
n <- 25
p <- 5
X <- matrix(rnorm(n * p), ncol = p)
y <- X[, 1] + rnorm(n)
fit1 <- HorseTrees(y = y, X_train = X, outcome_type = "continuous", 
                   number_of_trees = 5, N_post = 75, N_burn = 25, 
                   verbose = FALSE)

# Minimal example: binary outcome
X <- matrix(rnorm(n * p), ncol = p)
y <- ifelse(X[, 1] + rnorm(n) > 0, 1, 0)
fit2 <- HorseTrees(y = y, X_train = X, outcome_type = "binary", 
                   number_of_trees = 5, N_post = 75, N_burn = 25, 
                   verbose = FALSE)

# Minimal example: right-censored outcome
X <- matrix(rnorm(n * p), ncol = p)
time <- rexp(n, rate = 0.1)
status <- rbinom(n, 1, 0.7)
fit3 <- HorseTrees(y = time, status = status, X_train = X, 
                   outcome_type = "right-censored", number_of_trees = 5, 
                   N_post = 75, N_burn = 25, verbose = FALSE)

# Larger continuous example (not run automatically)
# \donttest{
n <- 100
n_test <- 50
p <- 100
X <- matrix(rnorm(n * p), ncol = p)
X_test <- matrix(rnorm(n_test * p), ncol = p)
y <- X[, 1] + X[, 2] - X[, 3] + rnorm(n, sd = 0.5)

fit4 <- HorseTrees(y = y,
                   X_train = X,
                   X_test = X_test,
                   outcome_type = "continuous",
                   number_of_trees = 200,
                   N_post = 2500,
                   N_burn = 2500,
                   store_posterior_sample = TRUE,
                   verbose = TRUE)

plot(fit4$sigma, type = "l", ylab = expression(sigma),
     xlab = "Iteration", main = "Sigma traceplot")

hist(fit4$train_predictions_sample[, 1],
     main = "Posterior distribution of the prediction for individual 1",
     xlab = "Prediction", breaks = 20)
# }
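In the right-censored example above, status is drawn independently of the follow-up times, which keeps the example minimal. A more realistic simulation, sketched below under the assumption of independent exponential event and censoring times (the rates and the covariate effect are arbitrary choices), records the earlier of the two times and sets status to 1 when the event occurs first. The resulting time and status can be passed as y and status with outcome_type = "right-censored".

```r
set.seed(123)
n <- 25
p <- 5
X <- matrix(rnorm(n * p), ncol = p)

# Event times depend on the first covariate (higher X[, 1] -> earlier events).
event_time <- rexp(n, rate = 0.1 * exp(0.5 * X[, 1]))
# Independent censoring times.
censor_time <- rexp(n, rate = 0.05)

# Observed follow-up time and event indicator (1 = event, 0 = censored).
time <- pmin(event_time, censor_time)
status <- as.integer(event_time <= censor_time)

mean(status)  # observed proportion of events
```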