# Example: Continuous outcome and homogeneous treatment effect
# Simulate a small data set: n observations with p uniform covariates for
# each of the two forests, and a fair-coin treatment assignment.
n <- 50
p <- 3
X_control <- matrix(runif(n * p), nrow = n, ncol = p)
X_treat <- matrix(runif(n * p), nrow = n, ncol = p)
treatment <- rbinom(n, size = 1, prob = 0.5)

# Constant treatment term: the outcome is shifted by +tau / 2 for
# treatment == 0 and by -tau / 2 for treatment == 1, plus standard
# normal noise.
tau <- 2
y <- X_control[, 1] + (0.5 - treatment) * tau + rnorm(n)

# Fit with a deliberately tiny forest and short chain so the example
# finishes quickly.
fit <- CausalHorseForest(
  y = y,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "continuous",
  number_of_trees = 5,
  N_burn = 5,
  N_post = 10,
  store_posterior_sample = TRUE,
  verbose = FALSE,
  seed = 1
)
# \donttest{
## Example: Right-censored survival outcome
# Set data dimensions
n <- 100
p <- 1000

# Generate covariates; the treatment forest uses the same covariates
X <- matrix(runif(n * p), ncol = p)
X_treat <- X

# Treatment assignment depends on the first covariate. Index the first
# COLUMN (length n, one value per subject); indexing the first row would
# give p values belonging to a single subject, of which rbinom() would
# silently use only the first n.
true_propensity <- pnorm(X_treat[, 1] - 1 / 2)
treatment <- rbinom(n, 1, true_propensity)

# Generate true survival times depending on X and treatment
linpred <- X[, 1] - X[, 2] +
  (treatment - 0.5) * (1 + X[, 2] / 2 + X[, 3] / 3 + X[, 4] / 4)
true_time <- linpred + rnorm(n, 0, 0.5)

# Generate censoring times (log of an exponential draw with mean 5)
censor_time <- log(rexp(n, rate = 1 / 5))

# Observed times and event indicator (1 = event observed, 0 = censored).
# Comparing the two time vectors directly avoids a floating-point `==`
# against the pmin() result and yields the same indicator.
time_obs <- pmin(true_time, censor_time)
status <- as.numeric(true_time <= censor_time)
# Estimate the propensity score: fit HorseTrees with the binary treatment
# indicator as the outcome and all covariates as predictors
fit_prop <- HorseTrees(
  y = treatment,
  X_train = X,
  outcome_type = "binary",
  number_of_trees = 200,
  N_burn = 1000,
  N_post = 1000
)

# The fitted training probabilities are used as the propensity score
propensity <- fit_prop$train_probabilities

# Prepend the propensity score as an extra column to the covariates that
# feed the control forest
X_control <- cbind(propensity, X)
# Fit the Causal Horseshoe Forest to the right-censored outcome.
# The control forest sees the covariates plus the propensity score; the
# treatment forest sees the raw covariates.
# NOTE(review): timescale = "log" presumably tells the model that y is
# already on the log scale -- confirm against the package documentation.
fit_surv <- CausalHorseForest(
  y = time_obs,
  status = status,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "right-censored",
  timescale = "log",
  number_of_trees = 200,
  k = 0.1,
  N_burn = 1000,
  N_post = 1000,
  store_posterior_sample = TRUE
)
## Evaluate and summarize results
# The C-index needs the optional 'survival' package; report and skip the
# computation when it is not installed
if (!requireNamespace("survival", quietly = TRUE)) {
  cat("Package 'survival' not available. Skipping C-index computation.\n")
} else {
  # Concordance between the observed (time, status) pairs and the
  # training predictions of the fitted model
  predicted_survtime <- fit_surv$train_predictions
  cindex_result <- survival::concordance(
    survival::Surv(time_obs, status) ~ predicted_survtime
  )
  c_index <- cindex_result$concordance
  cat("C-index:", round(c_index, 3), "\n")
}
# Compute posterior ATE samples: one average per row of the stored
# treatment-effect sample matrix (rows are presumably posterior draws and
# columns subjects -- confirm against the package documentation)
ate_samples <- rowMeans(fit_surv$train_predictions_sample_treat)
mean_ate <- mean(ate_samples)
ci_95 <- quantile(ate_samples, probs = c(0.025, 0.975))
cat("Posterior mean ATE:", round(mean_ate, 3), "\n")
cat("95% credible interval: [", round(ci_95[1], 3), ", ", round(ci_95[2], 3), "]\n", sep = "")

# True ATE of the simulation: the treatment effect in `linpred` is
# 1 + X2 / 2 + X3 / 3 + X4 / 4 with Uniform(0, 1) covariates (mean 1 / 2),
# so its expectation is 1 + 1/4 + 1/6 + 1/8 (about 1.5417)
true_ate <- 1 + 1 / 4 + 1 / 6 + 1 / 8

# Define each colour once so the legend always matches the drawn lines
col_estimate <- "orange3"
col_truth <- "darkred"

# Plot histogram of ATE samples
hist(
  ate_samples,
  breaks = 30,
  col = "steelblue",
  freq = FALSE,
  border = "white",
  xlab = "Average Treatment Effect (ATE)",
  main = "Posterior distribution of ATE"
)
abline(v = mean_ate, col = col_estimate, lwd = 2)
abline(v = ci_95, col = col_estimate, lty = 2, lwd = 2)
abline(v = true_ate, col = col_truth, lwd = 2)
legend(
  "topright",
  legend = c("Mean", "95% CI", "Truth"),
  col = c(col_estimate, col_estimate, col_truth),
  lty = c(1, 2, 1),
  lwd = 2
)
## Plot individual CATE estimates
# Summarize posterior distribution per patient: column-wise mean and
# 2.5% / 97.5% quantiles of the stored treatment-effect samples
posterior_matrix <- fit_surv$train_predictions_sample_treat
posterior_mean <- colMeans(posterior_matrix)
posterior_ci <- apply(posterior_matrix, 2, quantile, probs = c(0.025, 0.975))
df_cate <- data.frame(
  mean = posterior_mean,
  lower = posterior_ci[1, ],
  upper = posterior_ci[2, ]
)

# Sort patients by posterior mean CATE
df_cate_sorted <- df_cate[order(df_cate$mean), ]
n_patients <- nrow(df_cate_sorted)

# Vertical plotting positions, built once; seq_len() is used instead of
# 1:n_patients because it yields an empty vector (not c(1, 0)) for n = 0
patient_idx <- seq_len(n_patients)

# Create the (empty) plot with x-limits wide enough for every interval
plot(
  x = df_cate_sorted$mean,
  y = patient_idx,
  type = "n",
  xlab = "CATE per patient (95% credible interval)",
  ylab = "Patient index (sorted)",
  main = "Posterior CATE estimates",
  xlim = range(df_cate_sorted$lower, df_cate_sorted$upper)
)

# Add CATE intervals as horizontal segments, one per patient
segments(
  x0 = df_cate_sorted$lower,
  x1 = df_cate_sorted$upper,
  y0 = patient_idx,
  y1 = patient_idx,
  col = "steelblue"
)

# Add mean points
points(df_cate_sorted$mean, patient_idx, pch = 16, col = "orange3", lwd = 0.1)

# Add reference line at 0
abline(v = 0, col = "black", lwd = 2)
# }
# Run the code above in your browser using DataLab