# \donttest{
# Simulate a small data set: an n x p matrix of standard-normal draws.
n <- 50
p <- 3
X <- matrix(rnorm(n * p), n, p)

# Z: n Bernoulli(0.5) draws, passed to causalDT() as its third argument.
Z <- rbinom(n, size = 1, prob = 0.5)

# Y: adds 2 where Z == 1 and X[, 1] > 0, plus X[, 2], plus normal noise.
# NOTE(review): the noise term has mean 0.1 and the default sd of 1; confirm
# that a mean shift (rather than sd = 0.1) is what was intended.
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, mean = 0.1)

# Fit the causal distillation tree.
out <- causalDT(X, Y, Z)

# Recompute the subgroup CATEs by hand from the fitted student model, using
# only the rows indexed by out$holdout_idxs.
group_cates <- estimate_group_cates(
  out$student_fit$fit,
  X = X[out$holdout_idxs, , drop = FALSE],
  Y = Y[out$holdout_idxs],
  Z = Z[out$holdout_idxs]
)

# Compare the manual result against the estimate stored in the fit.
all.equal(out$estimate, group_cates)
# }
# Run the code above in your browser using DataLab