
policytree (version 1.0)

predict.multi_causal_forest: Predict with multi_causal_forest

Description

Computes estimates of \(\tau_a(x)\), the treatment effect for each treatment level \(a\), at the requested points.

Usage

# S3 method for multi_causal_forest
predict(object, newdata = NULL, ...)

Arguments

object

The trained forest.

newdata

Points at which predictions should be made. If NULL, makes out-of-bag predictions on the training set instead (i.e., provides predictions at \(X_i\) using only trees that did not use the i-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order.
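The two behaviours described above can be illustrated with a small base-R sketch. `oob_average()` mimics out-of-bag prediction on a toy ensemble. It takes a matrix of per-tree predictions and a logical matrix recording which training examples each tree used, and for each example i it averages only the trees that did not use i. `align_columns()` reorders a new data matrix into the training column order by name. Both helpers are hypothetical and assume that the columns are named. The picture is also deliberately simplified: grf builds its estimates from forest-based weights rather than by averaging tree outputs, so treat this as an illustration of the out-of-bag idea only.

```r
# Hypothetical sketch: out-of-bag (OOB) averaging over a toy ensemble.
# tree.preds[i, b] is tree b's prediction at training point i;
# in.bag[i, b] is TRUE if tree b was trained using example i.
oob_average <- function(tree.preds, in.bag) {
  stopifnot(identical(dim(tree.preds), dim(in.bag)))
  use <- !in.bag                      # keep only trees that did NOT see example i
  n.oob <- rowSums(use)               # number of OOB trees per example
  sums <- rowSums(tree.preds * use)   # sum of OOB tree predictions per example
  # Examples used by every tree have no OOB trees, so return NA for them.
  ifelse(n.oob > 0, sums / n.oob, NA_real_)
}

# Hypothetical helper: reorder newdata columns to match the training order.
# Assumes both the training and new matrices carry column names.
align_columns <- function(newdata, train.colnames) {
  missing.cols <- setdiff(train.colnames, colnames(newdata))
  if (length(missing.cols) > 0) {
    stop("newdata is missing columns: ", paste(missing.cols, collapse = ", "))
  }
  newdata[, train.colnames, drop = FALSE]
}

# Toy example: 3 training points, 2 trees.
tree.preds <- matrix(c(1, 3,
                       2, 4,
                       5, 7), nrow = 3, byrow = TRUE)
in.bag <- matrix(c(TRUE,  FALSE,
                   FALSE, FALSE,
                   TRUE,  TRUE), nrow = 3, byrow = TRUE)
oob_average(tree.preds, in.bag)  # 3, 3, NA

# New data whose columns arrive in a different order than in training.
X.new <- matrix(1:6, nrow = 2, dimnames = list(NULL, c("b", "a", "c")))
aligned <- align_columns(X.new, c("a", "b", "c"))
```

Aligning columns by name before calling `predict()` guards against silently wrong predictions. A matrix with the right number of columns in the wrong order would otherwise be accepted without complaint.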

...

Additional arguments passed to grf::predict.causal_forest.
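The `...` pass-through is ordinary R argument forwarding: whatever extra arguments the caller supplies reach each underlying `predict()` call unchanged. The sketch below shows that mechanism with plain `lm()` fits standing in for the per-treatment causal forests. It uses `se.fit = TRUE` as the forwarded argument. The function `predict_each()` is hypothetical and is not policytree's source. With grf, an argument such as `estimate.variance = TRUE` for `grf::predict.causal_forest` would travel the same way.

```r
# Hypothetical sketch of forwarding `...` to a per-treatment predict() call.
# lm() fits stand in for the per-treatment causal forests.
predict_each <- function(models, newdata = NULL, ...) {
  # Every extra argument in `...` reaches each model's predict() method.
  lapply(models, function(m) predict(m, newdata = newdata, ...))
}

set.seed(1)
df <- data.frame(x = rnorm(20))
df$y.A <- 1 + 2 * df$x + rnorm(20, sd = 0.1)
df$y.B <- -1 + 0.5 * df$x + rnorm(20, sd = 0.1)
models <- list(A = lm(y.A ~ x, data = df), B = lm(y.B ~ x, data = df))

new.points <- data.frame(x = c(0, 1))
# se.fit = TRUE is forwarded through `...` to predict.lm for each model.
preds <- predict_each(models, new.points, se.fit = TRUE)

# Collect the point predictions into one column per "treatment".
point.preds <- sapply(preds, function(p) p$fit)
```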

Value

A list containing a matrix of predictions and other estimates (debiased error, etc.) for each treatment.
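The sketch below shows how a prediction matrix of this kind can be handled in base R. It assumes a layout that is not spelled out above: one row per point and one column per treatment, with columns named after the treatment levels. A small hand-made matrix stands in for the `predictions` element. The argmax rule at the end is only a naive plug-in illustration. It is not a recommended policy-learning method; policytree's `policy_tree()` is the package's tool for that.

```r
# Toy stand-in for the `predictions` matrix (assumed layout:
# rows = points, columns = treatments named "A", "B", "C").
tau.hat <- matrix(c(0.5, -0.2,  1.0,
                    0.1,  0.3, -0.4),
                  nrow = 2, byrow = TRUE,
                  dimnames = list(NULL, c("A", "B", "C")))

# Estimates for a single treatment, extracted by column name.
tau.B <- tau.hat[, "B"]

# Naive plug-in rule: the treatment with the largest estimate in each row.
# ties.method = "first" makes the choice deterministic when estimates tie.
best.treatment <- colnames(tau.hat)[max.col(tau.hat, ties.method = "first")]
best.treatment  # "C" "B"
```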

Examples

# Load the package.
library(policytree)

# Train a multi causal forest.
n <- 250
p <- 10
d <- 3  # number of treatment levels ("A", "B", "C")
X <- matrix(rnorm(n * p), n, p)
W <- sample(c("A", "B", "C"), n, replace = TRUE)
Y <- X[, 1] + X[, 2] * (W == "B") + X[, 3] * (W == "C") + runif(n)
multi.forest <- multi_causal_forest(X = X, Y = Y, W = W)

# Predict using the forest.
multi.forest.pred <- predict(multi.forest)
head(multi.forest.pred$predictions)
