
ShrinkageTrees (version 1.0.0)

CausalShrinkageForest: General Causal Shrinkage Forests

Description

Fits a Bayesian Causal Shrinkage Forest model to estimate heterogeneous treatment effects. This function generalizes CausalHorseForest by allowing flexible global-local shrinkage priors on the step heights of both the control and the treatment forest. It supports continuous outcomes and right-censored survival outcomes.

Usage

CausalShrinkageForest(
  y,
  status = NULL,
  X_train_control,
  X_train_treat,
  treatment_indicator_train,
  X_test_control = NULL,
  X_test_treat = NULL,
  treatment_indicator_test = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees_control = 200,
  number_of_trees_treat = 200,
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control,
  local_hp_treat,
  global_hp_control = NULL,
  global_hp_treat = NULL,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 5000,
  N_burn = 5000,
  delayed_proposal = 5,
  store_posterior_sample = FALSE,
  seed = NULL,
  verbose = TRUE
)

Value

A list containing:

train_predictions

Posterior mean predictions on training data (combined forest).

test_predictions

Posterior mean predictions on test data (combined forest).

train_predictions_control

Estimated control outcomes on training data.

test_predictions_control

Estimated control outcomes on test data.

train_predictions_treat

Estimated treatment effects on training data.

test_predictions_treat

Estimated treatment effects on test data.

sigma

Vector of posterior samples for the error standard deviation.

acceptance_ratio_control

Average acceptance ratio in control forest.

acceptance_ratio_treat

Average acceptance ratio in treatment forest.

train_predictions_sample_control

Matrix of posterior samples for control predictions on training data (if store_posterior_sample = TRUE).

test_predictions_sample_control

Matrix of posterior samples for control predictions on test data (if store_posterior_sample = TRUE).

train_predictions_sample_treat

Matrix of posterior samples for treatment effects on training data (if store_posterior_sample = TRUE).

test_predictions_sample_treat

Matrix of posterior samples for treatment effects on test data (if store_posterior_sample = TRUE).
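The sample matrices can be summarized with base R alone. The example section below averages them with colMeans() to get CATEs and with rowMeans() to get ATE draws, which suggests that rows are posterior draws and columns are observations. The sketch below assumes that layout and uses a simulated stand-in matrix (the name `samples` is hypothetical, not real model output) to compute pointwise 95% credible intervals for the CATEs and for the sample ATE.

```r
# Stand-in for e.g. fit$train_predictions_sample_treat.
# Assumed layout: N_post rows (posterior draws) x n columns (observations).
set.seed(1)
n_draws <- 200
n_obs <- 5
samples <- matrix(rnorm(n_draws * n_obs, mean = 2, sd = 0.5), nrow = n_draws)

# Posterior mean and 95% credible interval of each observation's CATE
cate_mean <- colMeans(samples)
cate_ci <- apply(samples, 2, quantile, probs = c(0.025, 0.975))

# Posterior draws of the sample-average treatment effect and its 95% interval
ate_draws <- rowMeans(samples)
ate_ci <- quantile(ate_draws, probs = c(0.025, 0.975))
```

cate_ci has one column per observation, with the lower and upper interval limits in its two rows.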

Arguments

y

Numeric outcome vector: continuous outcomes, or follow-up times for survival outcomes.

status

Optional event indicator vector (1 = event occurred, 0 = censored). Required when outcome_type = "right-censored".

X_train_control

Covariate matrix for the control forest. Rows correspond to samples, columns to covariates.

X_train_treat

Covariate matrix for the treatment forest.

treatment_indicator_train

Vector indicating treatment assignment for training samples (1 = treated, 0 = control).

X_test_control

Optional covariate matrix for control forest test data. Defaults to column means of X_train_control if NULL.

X_test_treat

Optional covariate matrix for treatment forest test data. Defaults to column means of X_train_treat if NULL.

treatment_indicator_test

Optional vector indicating treatment assignment for test data.

outcome_type

Type of outcome: one of "continuous" or "right-censored". Default is "continuous".

timescale

For survival outcomes: either "time" (original scale, log-transformed internally) or "log" (already log-transformed). Default is "time".

number_of_trees_control

Number of trees in the control forest. Default is 200.

number_of_trees_treat

Number of trees in the treatment forest. Default is 200.

prior_type_control

Type of prior on control forest step heights. One of "horseshoe", "horseshoe_fw", "horseshoe_EB", or "half-cauchy". Default is "horseshoe".

prior_type_treat

Type of prior on treatment forest step heights. Same options as prior_type_control.

local_hp_control

Local hyperparameter controlling shrinkage of individual step heights in the control forest. Required for all prior types.

local_hp_treat

Local hyperparameter for the treatment forest. Required for all prior types.

global_hp_control

Global hyperparameter for the control forest. Required for horseshoe-type priors; ignored for "half-cauchy".

global_hp_treat

Global hyperparameter for the treatment forest. Required for horseshoe-type priors; ignored for "half-cauchy".

power

Power parameter of the tree structure prior. Default is 2.0.

base

Base parameter of the tree structure prior. Default is 0.95.

p_grow

Probability of proposing a grow move. Default is 0.4.

p_prune

Probability of proposing a prune move. Default is 0.4.

nu

Degrees of freedom for the error variance prior. Default is 3.

q

Quantile parameter of the error variance prior. Default is 0.90.

sigma

Optional known standard deviation of the outcome. If NULL, estimated from data.

N_post

Number of posterior samples to store. Default is 5000.

N_burn

Number of burn-in iterations. Default is 5000.

delayed_proposal

Number of delayed iterations before proposal updates. Default is 5.

store_posterior_sample

Logical; whether to store posterior samples of predictions. Default is FALSE.

seed

Random seed for reproducibility. Default is NULL.

verbose

Logical; whether to print verbose output. Default is TRUE.
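The tree structure prior (power, base) and the error variance prior (nu, q) are easiest to understand numerically. The sketch below assumes that the package follows the standard BART conventions, which this page does not confirm: a node at depth d (root = 0) splits with prior probability base * (1 + d)^(-power), and sigma^2 has a scaled inverse chi-squared prior nu * lambda / chisq(nu) whose scale lambda is chosen so that P(sigma < sigma_hat) = q for a rough data-based estimate sigma_hat. The helpers split_prob(), calibrate_lambda() and prob_sigma_below() are hypothetical and not part of ShrinkageTrees.

```r
# Hypothetical helpers, assuming standard BART prior conventions.

# Prior probability that a node at depth d (root = 0) is split
split_prob <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

# Scale lambda of the prior sigma^2 ~ nu * lambda / chisq(nu),
# chosen so that P(sigma < sigma_hat) = q
calibrate_lambda <- function(sigma_hat, nu = 3, q = 0.9) {
  sigma_hat^2 * qchisq(1 - q, df = nu) / nu
}

# Prior probability that sigma < sigma_hat for a given lambda
prob_sigma_below <- function(sigma_hat, lambda, nu = 3) {
  1 - pchisq(nu * lambda / sigma_hat^2, df = nu)
}

round(split_prob(0:3), 4)
# [1] 0.9500 0.2375 0.1056 0.0594

sigma_hat <- 1.5  # e.g. sd(y) as a rough data-based estimate
lambda <- calibrate_lambda(sigma_hat)
prob_sigma_below(sigma_hat, lambda)
# [1] 0.9
```

Under these assumptions, the defaults give a depth-1 node only about a 24% prior chance of splitting, which keeps individual trees shallow.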

Details

This function is a flexible generalization of CausalHorseForest. The Causal Shrinkage Forest model decomposes the outcome into a prognostic (control) and a treatment effect part. Each part is modeled by its own shrinkage tree ensemble, with separate flexible global-local shrinkage priors. It is particularly useful for estimating heterogeneous treatment effects in high-dimensional settings.
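To make the prognostic/treatment decomposition concrete, the sketch below simulates data from the additive form y = mu(x) + z * cate(x) + error with z in {0, 1}. The functions mu_fn() and cate_fn() are made up for illustration, and the package's internal treatment coding may differ (the example below, for instance, simulates with a centred 0.5 - treat term).

```r
# Hypothetical data-generating process for the decomposition
# y = mu(x) + z * cate(x) + error.
set.seed(2)
n <- 100
x <- matrix(runif(n * 2), ncol = 2)
mu_fn <- function(x) 1 + x[, 1]          # prognostic (control) part
cate_fn <- function(x) 0.5 + 2 * x[, 2]  # heterogeneous treatment effect part
z <- rbinom(n, 1, 0.5)                   # treatment indicator (1 = treated)
y <- mu_fn(x) + z * cate_fn(x) + rnorm(n, sd = 0.5)

# True CATE per unit, and its sample average (the sample ATE)
true_cate <- cate_fn(x)
true_ate <- mean(true_cate)
```

In this setting the control forest would target mu_fn() and the treatment forest would target cate_fn().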

The horseshoe prior is the fully Bayesian global-local shrinkage prior, where both the global and local shrinkage parameters are assigned half-Cauchy distributions with scale hyperparameters global_hp and local_hp, respectively. The global shrinkage parameter is defined separately for each tree, allowing adaptive regularization per tree.

The horseshoe_fw prior (forest-wide horseshoe) is similar to horseshoe, except that a single global shrinkage parameter is shared by all trees in the forest.

The horseshoe_EB prior is an empirical Bayes variant of the horseshoe prior. Here, the global shrinkage parameter (\(\tau\)) is not assigned a prior distribution; instead, its value is fixed by the user through global_hp_control or global_hp_treat, while the local shrinkage parameters still follow half-Cauchy priors. Note: the software does not estimate \(\tau\); it must be supplied by the user.

The half-cauchy prior considers only local shrinkage and does not include a global shrinkage component. It places a half-Cauchy prior on each local shrinkage parameter with scale hyperparameter local_hp.
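The four prior types differ mainly in where the scale of each leaf step height comes from. The sketch below draws step heights from each variant under the assumption that a step height is Gaussian given its scales, h ~ N(0, (lambda * tau)^2); this is an illustration, not the package's internal sampler, and draw_step_heights() and rhalfcauchy() are hypothetical helpers.

```r
# Half-Cauchy draws via the absolute value of a centred Cauchy
rhalfcauchy <- function(n, scale) abs(rcauchy(n, location = 0, scale = scale))

# Hypothetical prior draws of leaf step heights for a forest of n_trees trees,
# each with n_leaves leaves, assuming h ~ N(0, (lambda * tau)^2).
draw_step_heights <- function(prior_type, n_trees, n_leaves,
                              local_hp, global_hp = NULL) {
  prior_type <- match.arg(prior_type,
                          c("horseshoe", "horseshoe_fw", "horseshoe_EB", "half-cauchy"))
  if (prior_type == "horseshoe_fw") {
    tau_forest <- rhalfcauchy(1, global_hp)  # one global scale for the whole forest
  }
  lapply(seq_len(n_trees), function(t) {
    tau <- switch(prior_type,
      "horseshoe"    = rhalfcauchy(1, global_hp),  # separate global scale per tree
      "horseshoe_fw" = tau_forest,                 # shared forest-wide global scale
      "horseshoe_EB" = global_hp,                  # global scale fixed by the user
      "half-cauchy"  = 1)                          # no global component
    lambda <- rhalfcauchy(n_leaves, local_hp)      # one local scale per leaf
    rnorm(n_leaves, mean = 0, sd = lambda * tau)
  })
}

set.seed(1)
hs <- draw_step_heights("horseshoe", n_trees = 5, n_leaves = 3,
                        local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5))
hc <- draw_step_heights("half-cauchy", n_trees = 5, n_leaves = 3,
                        local_hp = 1 / sqrt(5))
```

Each call returns a list with one vector of step heights per tree; only the horseshoe variants use global_hp.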

See Also

CausalHorseForest, ShrinkageTrees, HorseTrees

Examples

# Example: continuous outcome, homogeneous treatment effect, two priors
library(ShrinkageTrees)

set.seed(1)
n <- 50
p <- 3
X <- matrix(runif(n * p), ncol = p)
X_treat <- X_control <- X
treat <- rbinom(n, 1, X[, 1])
tau <- 2
# With this coding, the true treatment effect is -tau for every unit
y <- X[, 1] + (0.5 - treat) * tau + rnorm(n)

# Fit a standard Causal Horseshoe Forest
fit_horseshoe <- CausalShrinkageForest(y = y,
                                       X_train_control = X_control,
                                       X_train_treat = X_treat,
                                       treatment_indicator_train = treat,
                                       outcome_type = "continuous",
                                       number_of_trees_treat = 5,
                                       number_of_trees_control = 5,
                                       prior_type_control = "horseshoe",
                                       prior_type_treat = "horseshoe",
                                       local_hp_control = 0.1/sqrt(5),
                                       local_hp_treat = 0.1/sqrt(5),
                                       global_hp_control = 0.1/sqrt(5),
                                       global_hp_treat = 0.1/sqrt(5),
                                       N_post = 10,
                                       N_burn = 5,
                                       store_posterior_sample = TRUE,
                                       verbose = FALSE,
                                       seed = 1
)

# Fit a Causal Shrinkage Forest with half-cauchy prior
fit_halfcauchy <- CausalShrinkageForest(y = y,
                                        X_train_control = X_control,
                                        X_train_treat = X_treat,
                                        treatment_indicator_train = treat,
                                        outcome_type = "continuous",
                                        number_of_trees_treat = 5,
                                        number_of_trees_control = 5,
                                        prior_type_control = "half-cauchy",
                                        prior_type_treat = "half-cauchy",
                                        local_hp_control = 1/sqrt(5),
                                        local_hp_treat = 1/sqrt(5),
                                        N_post = 10,
                                        N_burn = 5,
                                        store_posterior_sample = TRUE,
                                        verbose = FALSE,
                                        seed = 1
)

# Posterior mean CATEs
CATE_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample_treat)
CATE_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample_treat)

# Posterior draws of the ATE (one per stored sample)
post_ATE_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample_treat)
post_ATE_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample_treat)

# Posterior mean ATE
ATE_horseshoe <- mean(post_ATE_horseshoe)
ATE_halfcauchy <- mean(post_ATE_halfcauchy)

