## Simulate the data. The seed is fixed so every run draws the same sample.
set.seed(1986)
n <- 1000
k <- 3

## Covariates: an n-by-k matrix of standard normal draws, columns x1, ..., xk.
X <- matrix(
  rnorm(n * k),
  nrow = n,
  ncol = k,
  dimnames = list(NULL, paste0("x", seq_len(k)))
)

## Binary treatment indicator, assigned with probability one half.
D <- rbinom(n, size = 1, prob = 0.5)

## Conditional means under control (mu0) and under treatment (mu1); they
## differ by exactly the second covariate.
mu0 <- 0.5 * X[, 1]
mu1 <- mu0 + X[, 2]

## Observed outcome: the conditional mean selected by D plus standard normal
## noise.
Y <- mu0 + D * (mu1 - mu0) + rnorm(n)
## Split the sample into a training part and an honest part.
## NOTE(review): sample_split() is not defined in this file; it is presumably
## provided by an attached package (e.g., aggTrees) -- confirm it is loaded.
splits <- sample_split(length(Y), training_frac = 0.5)
training_idx <- splits$training_idx
honest_idx <- splits$honest_idx

## Training sample.
Y_tr <- Y[training_idx]
D_tr <- D[training_idx]
## drop = FALSE keeps X a matrix (with its x* column names) even when k is 1
## or a split contains a single row; without it the subset silently collapses
## to a plain vector and the column names are lost.
X_tr <- X[training_idx, , drop = FALSE]

## Honest sample.
Y_hon <- Y[honest_idx]
D_hon <- D[honest_idx]
X_hon <- X[honest_idx, , drop = FALSE]
## Grow a tree of depth at most 2 on the training sample, regressing the
## outcome on all covariates.
library(rpart)
tree <- rpart(
  Y ~ .,
  data = data.frame(Y = Y_tr, X_tr),
  control = rpart.control(maxdepth = 2)
)

## Re-estimate the GATEs in every node (internal and terminal) with the honest
## sample, then display the resulting tree.
new_tree <- estimate_rpart(tree, Y_hon, D_hon, X_hon, method = "raw")
new_tree$tree
## Run the code above in your browser using DataLab