
ShrinkageTrees (version 1.0.0)

CausalHorseForest: Causal Horseshoe Forests

Description

This function fits a (Bayesian) Causal Horseshoe Forest. It can be used to estimate conditional average treatment effects (CATEs) for continuous or right-censored survival outcomes given high-dimensional covariates. The outcome is decomposed into a prognostic part (control) and a treatment effect part, and each part is modeled with its own Horseshoe Trees regression function.
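The decomposition can be illustrated with simulated data. The sketch below is a minimal base-R simulation, not package code. It assumes that treatment enters the outcome as (z - 1/2) * tau(x), which is the coding used in the simulated survival example on this page. The functions mu and tau are hypothetical choices made for illustration. Under this coding, the difference between an individual's two potential outcomes equals tau(x), the conditional treatment effect.

```r
# Hypothetical illustration of the prognostic + treatment-effect decomposition,
# assuming treatment enters as (z - 1/2) * tau(x), as in the simulated examples.
set.seed(42)
n <- 200
x <- runif(n)
z <- rbinom(n, 1, 0.5)

mu  <- function(x) 2 * x        # prognostic (control) function, chosen for illustration
tau <- function(x) 1 + x / 2    # treatment effect function, chosen for illustration

eps <- rnorm(n, sd = 0.5)
y <- mu(x) + (z - 0.5) * tau(x) + eps   # observed outcome

# Potential outcomes under treatment (z = 1) and control (z = 0), sharing the noise
y1 <- mu(x) + 0.5 * tau(x) + eps
y0 <- mu(x) - 0.5 * tau(x) + eps

# The individual contrast y1 - y0 recovers tau(x)
cate <- y1 - y0
```

The control forest targets mu(x) and the treatment forest targets tau(x). Because the two parts are modeled separately, each can receive its own amount of shrinkage.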

Usage

CausalHorseForest(
  y,
  status = NULL,
  X_train_control,
  X_train_treat,
  treatment_indicator_train,
  X_test_control = NULL,
  X_test_treat = NULL,
  treatment_indicator_test = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees = 200,
  k = 0.1,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 5000,
  N_burn = 5000,
  delayed_proposal = 5,
  store_posterior_sample = FALSE,
  seed = NULL,
  verbose = TRUE
)

Value

A list containing:

train_predictions

Posterior mean predictions on training data (combined forest).

test_predictions

Posterior mean predictions on test data (combined forest).

train_predictions_control

Estimated control outcomes on training data.

test_predictions_control

Estimated control outcomes on test data.

train_predictions_treat

Estimated treatment effects on training data.

test_predictions_treat

Estimated treatment effects on test data.

sigma

Vector of posterior samples for the error standard deviation.

acceptance_ratio_control

Average acceptance ratio in control forest.

acceptance_ratio_treat

Average acceptance ratio in treatment forest.

train_predictions_sample_control

Matrix of posterior samples for control predictions on training data (if store_posterior_sample = TRUE).

test_predictions_sample_control

Matrix of posterior samples for control predictions on test data (if store_posterior_sample = TRUE).

train_predictions_sample_treat

Matrix of posterior samples for treatment effects on training data (if store_posterior_sample = TRUE).

test_predictions_sample_treat

Matrix of posterior samples for treatment effects on test data (if store_posterior_sample = TRUE).
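Posterior summaries can be computed from the sample matrices. The sketch below assumes that each sample matrix has one row per retained posterior draw (N_post rows) and one column per observation. This is the orientation the Examples rely on when they apply rowMeans to obtain ATE draws and colMeans to obtain per-patient CATEs. The matrix here is simulated as a stand-in and is not package output.

```r
# Stand-in for a matrix such as train_predictions_sample_treat, assuming
# rows = posterior draws and columns = observations (as the examples use it).
set.seed(1)
N_post <- 500
n <- 20
true_cate <- seq(0.5, 2.5, length.out = n)
draws <- matrix(rnorm(N_post * n, mean = rep(true_cate, each = N_post), sd = 0.3),
                nrow = N_post, ncol = n)

# Per-observation posterior mean (one CATE estimate per column)
cate_mean <- colMeans(draws)

# Per-observation 95% credible intervals: a 2 x n matrix
cate_ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

# One ATE value per posterior draw: average over observations within each row
ate_draws <- rowMeans(draws)
ate_mean <- mean(ate_draws)
ate_ci <- quantile(ate_draws, probs = c(0.025, 0.975))
```

The ATE credible interval is computed from the per-draw averages, not from the per-patient intervals. This preserves the posterior correlation between patients.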

Arguments

y

Outcome vector. For survival outcomes, the observed follow-up times, on the original or log scale as specified by timescale.

status

Optional event indicator vector (1 = event occurred, 0 = censored). Required when outcome_type = "right-censored".

X_train_control

Covariate matrix for the control forest. Rows correspond to samples, columns to covariates.

X_train_treat

Covariate matrix for the treatment forest. Rows correspond to samples, columns to covariates.

treatment_indicator_train

Vector indicating treatment assignment for training samples (1 = treated, 0 = control).

X_test_control

Optional test covariate matrix for control forest. If NULL, defaults to column means of X_train_control.

X_test_treat

Optional test covariate matrix for treatment forest. If NULL, defaults to column means of X_train_treat.
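When no test matrices are supplied, the documented default uses the column means of the training covariates. The sketch below assumes this default is a single-row matrix of column means. In that case the default test predictions describe one "average" covariate profile rather than new subjects. The data are hypothetical.

```r
# Hypothetical training covariates: 50 samples, 4 covariates
set.seed(3)
X_train_control <- matrix(runif(50 * 4), ncol = 4)

# Assumed shape of the default test input: one row holding the column means
X_test_default <- matrix(colMeans(X_train_control), nrow = 1)
```

To obtain predictions for specific new subjects, pass X_test_control and X_test_treat explicitly.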

treatment_indicator_test

Optional vector indicating treatment assignment for test samples.

outcome_type

Type of outcome: one of "continuous" or "right-censored". Default is "continuous".

timescale

For survival outcomes: either "time" (original time scale, log-transformed internally) or "log" (already log-transformed).
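The two timescale settings should describe the same data. The sketch below builds right-censored data on the original time scale and prepares matching inputs for both settings. It assumes "log-transformed internally" refers to the natural logarithm, as computed by R's log(). The simulated event and censoring times are hypothetical.

```r
# Hypothetical simulated data on the original time scale
set.seed(7)
n <- 100
event_time  <- rexp(n, rate = 1 / 10)
censor_time <- rexp(n, rate = 1 / 15)

# Observed follow-up time and event indicator (1 = event, 0 = censored)
time_obs <- pmin(event_time, censor_time)
status   <- as.numeric(event_time <= censor_time)

# Matching y inputs for the two timescale settings
y_time <- time_obs        # use with timescale = "time"
y_log  <- log(time_obs)   # use with timescale = "log"
```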

number_of_trees

Number of trees in each forest. Default is 200.

k

Horseshoe prior scale hyperparameter. Default is 0.1. Controls global-local shrinkage on step heights.
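Global-local shrinkage can be shown with a direct simulation. The sketch below uses the standard horseshoe construction: each step height is normal with standard deviation k * lambda, and the local scale lambda is half-Cauchy(0, 1). It assumes that k acts as a fixed global scale. The package may scale k differently (for example, by the number of trees), so treat this only as an illustration of the prior's shape.

```r
set.seed(11)
n_draws <- 1e5
k <- 0.1

# Local scales: half-Cauchy(0, 1), one per step height
lambda <- abs(rcauchy(n_draws))

# Step heights under the horseshoe: N(0, (k * lambda)^2)
h_horseshoe <- rnorm(n_draws, mean = 0, sd = k * lambda)

# Gaussian step heights with the same global scale, for comparison
h_gauss <- rnorm(n_draws, mean = 0, sd = k)

# The horseshoe puts more mass near zero ...
near_zero <- c(horseshoe = mean(abs(h_horseshoe) < 0.01),
               gauss     = mean(abs(h_gauss) < 0.01))
# ... and far more mass in the tails
tails <- c(horseshoe = mean(abs(h_horseshoe) > 0.5),
           gauss     = mean(abs(h_gauss) > 0.5))
```

This combination of a spike near zero with heavy tails is what allows most step heights to be shrunk strongly while a few large effects remain largely unshrunk.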

power

Power parameter for tree structure prior. Default is 2.0.

base

Base parameter for tree structure prior. Default is 0.95.
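The power and base parameters are named after the standard BART tree prior. In that prior, a node at depth d is split (made non-terminal) with probability base * (1 + d)^(-power). The sketch below assumes the package uses this form and evaluates it at the default values.

```r
# Prior probability that a node at depth d is split, assuming the
# standard BART form base * (1 + d)^(-power)
split_prob <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

depths <- 0:4
p_split <- split_prob(depths)
round(p_split, 4)
```

Under this form, a larger power or a smaller base makes deep splits less likely a priori and so favours shallower trees.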

p_grow

Probability of proposing a grow move. Default is 0.4.

p_prune

Probability of proposing a prune move. Default is 0.4.

nu

Degrees of freedom for the error variance prior. Default is 3.

q

Quantile parameter for error variance prior. Default is 0.90.
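The nu and q parameters match the standard BART calibration of the error variance prior. In that calibration, sigma^2 ~ nu * lambda / chisq(nu), and lambda is chosen so that P(sigma < sigma_hat) = q, where sigma_hat is a rough data-based estimate of the residual standard deviation. The sketch below assumes the package follows this convention. The value of sigma_hat is hypothetical.

```r
# Calibrate the scaled-inverse-chi-square prior for sigma^2, assuming
# sigma^2 ~ nu * lambda / chisq(nu) and P(sigma < sigma_hat) = q
calibrate_lambda <- function(sigma_hat, nu = 3, q = 0.9) {
  sigma_hat^2 * qchisq(1 - q, df = nu) / nu
}

sigma_hat <- 1.2   # hypothetical rough estimate of the residual SD
nu <- 3
q <- 0.9
lambda <- calibrate_lambda(sigma_hat, nu, q)

# Exact check: P(sigma^2 < sigma_hat^2) = P(chisq(nu) > nu * lambda / sigma_hat^2)
prob_below <- pchisq(nu * lambda / sigma_hat^2, df = nu, lower.tail = FALSE)

# Monte Carlo check by sampling sigma^2 from the prior
set.seed(9)
sigma2_draws <- nu * lambda / rchisq(1e5, df = nu)
mc_prob <- mean(sqrt(sigma2_draws) < sigma_hat)
```

A higher q places more prior mass below sigma_hat, which reflects an expectation that the trees will explain a larger share of the variance.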

sigma

Optional known standard deviation of the outcome. If NULL, estimated from data.

N_post

Number of posterior samples to store. Default is 5000.

N_burn

Number of burn-in iterations. Default is 5000.

delayed_proposal

Number of delayed iterations before proposal updates. Default is 5.

store_posterior_sample

Logical; whether to store posterior samples of predictions. Default is FALSE.
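Storing posterior samples can use a lot of memory. The sketch below gives a rough estimate, assuming each stored sample matrix is an N_post x n matrix of doubles (8 bytes per entry). Up to four such matrices may be returned (control and treatment, training and test).

```r
# Rough memory estimate (in MB) for one stored sample matrix, assuming an
# N_post x n matrix of doubles at 8 bytes per entry
sample_matrix_mb <- function(N_post, n) N_post * n * 8 / 1024^2

est_mb <- sample_matrix_mb(N_post = 5000, n = 1000)   # about 38 MB per matrix

# Sanity check against R's own size for a smaller matrix of doubles
m <- matrix(0, nrow = 500, ncol = 100)
actual_bytes <- as.numeric(object.size(m))
```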

seed

Random seed for reproducibility. Default is NULL.

verbose

Logical; whether to print verbose output during sampling. Default is TRUE.

Details

The model separately regularizes the control and treatment trees using Horseshoe priors with global-local shrinkage on the step heights. This approach is designed for robust estimation of heterogeneous treatment effects in high-dimensional settings. It supports continuous and right-censored survival outcomes.

See Also

HorseTrees, ShrinkageTrees, CausalShrinkageForest

Examples

# Example: Continuous outcome and homogeneous treatment effect
n <- 50
p <- 3
X_control <- matrix(runif(n * p), ncol = p)
X_treat <- matrix(runif(n * p), ncol = p)
treatment <- rbinom(n, 1, 0.5)
tau <- 2
y <- X_control[, 1] + (treatment - 0.5) * tau + rnorm(n)

fit <- CausalHorseForest(
  y = y,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "continuous",
  number_of_trees = 5,
  N_post = 10,
  N_burn = 5,
  store_posterior_sample = TRUE,
  verbose = FALSE,
  seed = 1
)

# \donttest{
## Example: Right-censored survival outcome
# Set data dimensions
n <- 100
p <- 1000

# Generate covariates
X <- matrix(runif(n * p), ncol = p)
X_treat <- X
# Treatment assignment depends on the first covariate (confounding)
treatment <- rbinom(n, 1, pnorm(X_treat[, 1] - 1/2))

# Generate true survival times depending on X and treatment
linpred <- X[, 1] - X[, 2] +
  (treatment - 0.5) * (1 + X[, 2] / 2 + X[, 3] / 3 + X[, 4] / 4)
true_time <- linpred + rnorm(n, 0, 0.5)

# Generate censoring times
censor_time <- log(rexp(n, rate = 1 / 5))

# Observed times and event indicator
time_obs <- pmin(true_time, censor_time)
status <- as.numeric(true_time == time_obs)

# Estimate propensity score using HorseTrees
fit_prop <- HorseTrees(
  y = treatment,
  X_train = X,
  outcome_type = "binary",
  number_of_trees = 200,
  N_post = 1000,
  N_burn = 1000
)

# Retrieve estimated probability of treatment (propensity score)
propensity <- fit_prop$train_probabilities

# Combine propensity score with covariates for control forest
X_control <- cbind(propensity, X)

# Fit the Causal Horseshoe Forest for survival outcome
fit_surv <- CausalHorseForest(
  y = time_obs,
  status = status,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "right-censored",
  timescale = "log",
  number_of_trees = 200,
  k = 0.1,
  N_post = 1000,
  N_burn = 1000,
  store_posterior_sample = TRUE
)

## Evaluate and summarize results

# Evaluate C-index if survival package is available
if (requireNamespace("survival", quietly = TRUE)) {
  predicted_survtime <- fit_surv$train_predictions
  cindex_result <- survival::concordance(survival::Surv(time_obs, status) ~ predicted_survtime)
  c_index <- cindex_result$concordance
  cat("C-index:", round(c_index, 3), "\n")
} else {
  cat("Package 'survival' not available. Skipping C-index computation.\n")
}

# Compute posterior ATE samples
ate_samples <- rowMeans(fit_surv$train_predictions_sample_treat)
mean_ate <- mean(ate_samples)
ci_95 <- quantile(ate_samples, probs = c(0.025, 0.975))

cat("Posterior mean ATE:", round(mean_ate, 3), "\n")
cat("95% credible interval: [", round(ci_95[1], 3), ", ", round(ci_95[2], 3), "]\n", sep = "")

# Plot histogram of ATE samples
hist(
  ate_samples,
  breaks = 30,
  col = "steelblue",
  freq = FALSE,
  border = "white",
  xlab = "Average Treatment Effect (ATE)",
  main = "Posterior distribution of ATE"
)
abline(v = mean_ate, col = "orange3", lwd = 2)
abline(v = ci_95, col = "orange3", lty = 2, lwd = 2)
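# True ATE on the log-time scale: E[1 + X2/2 + X3/3 + X4/4] with X ~ U(0, 1),
# i.e. 1 + 1/4 + 1/6 + 1/8 = 37/24 (about 1.541667)
abline(v = 37 / 24, col = "darkred", lwd = 2)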
legend(
  "topright",
  legend = c("Mean", "95% CI", "Truth"),
  col = c("orange3", "orange3", "darkred"),
  lty = c(1, 2, 1),
  lwd = 2
)

## Plot individual CATE estimates

# Summarize posterior distribution per patient
posterior_matrix <- fit_surv$train_predictions_sample_treat
posterior_mean <- colMeans(posterior_matrix)
posterior_ci <- apply(posterior_matrix, 2, quantile, probs = c(0.025, 0.975))

df_cate <- data.frame(
  mean = posterior_mean,
  lower = posterior_ci[1, ],
  upper = posterior_ci[2, ]
)

# Sort patients by posterior mean CATE
df_cate_sorted <- df_cate[order(df_cate$mean), ]
n_patients <- nrow(df_cate_sorted)

# Create the plot
plot(
  x = df_cate_sorted$mean,
  y = 1:n_patients,
  type = "n",
  xlab = "CATE per patient (95% credible interval)",
  ylab = "Patient index (sorted)",
  main = "Posterior CATE estimates",
  xlim = range(df_cate_sorted$lower, df_cate_sorted$upper)
)

# Add CATE intervals
segments(
  x0 = df_cate_sorted$lower,
  x1 = df_cate_sorted$upper,
  y0 = 1:n_patients,
  y1 = 1:n_patients,
  col = "steelblue"
)

# Add mean points
points(df_cate_sorted$mean, 1:n_patients, pch = 16, col = "orange3", lwd = 0.1)

# Add reference line at 0
abline(v = 0, col = "black", lwd = 2)

# }
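The example computes the C-index with survival::concordance. The sketch below shows the idea behind Harrell's C-index in base R; it is a simplified version, not the survival package's algorithm. A pair of subjects is comparable when the subject with the shorter observed time had an event. The pair is concordant when that subject also has the shorter predicted survival time, and ties in the predictions count as 1/2. For simplicity, pairs with tied observed times are skipped. survival::concordance handles ties more carefully, so its values can differ slightly on tied data.

```r
# Simplified Harrell's C-index for right-censored data (base R sketch).
# pred: predicted survival times (larger = longer predicted survival).
# Pairs with tied observed times are skipped for simplicity.
harrell_c <- function(obs_time, status, pred) {
  concordant <- 0
  comparable <- 0
  n <- length(obs_time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next              # the earlier subject must have an observed event
    for (j in seq_len(n)) {
      if (obs_time[i] < obs_time[j]) {    # j outlived i, so the pair is comparable
        comparable <- comparable + 1
        if (pred[i] < pred[j]) {
          concordant <- concordant + 1
        } else if (pred[i] == pred[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

obs_time <- c(2, 4, 5, 7, 9)
status   <- c(1, 1, 0, 1, 0)
harrell_c(obs_time, status, pred = obs_time)        # perfectly ordered predictions
harrell_c(obs_time, status, pred = rev(obs_time))   # perfectly reversed predictions
```

A value of 1 means that every comparable pair is ranked correctly, and 0.5 corresponds to random ranking.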
