policytree (version 1.0.4)

predict.policy_tree: Predict method for policy_tree

Description

Predict the assigned action, or the leaf node, for each sample using a fitted policy_tree object.

Usage

# S3 method for policy_tree
predict(object, newdata, type = c("action.id", "node.id"), ...)

Arguments

object

A fitted policy_tree object.

newdata

A data frame or matrix of features, with columns matching those used to fit the tree.

type

The type of prediction required: "action.id" returns the id of the action assigned to the sample, and "node.id" returns the integer id of the leaf node the sample falls into. Default is "action.id".

...

Additional arguments (currently ignored).
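
The vector default type = c("action.id", "node.id") in the usage line follows a common R idiom in which match.arg() selects the first choice when the argument is not supplied. The sketch below illustrates that idiom with a hypothetical function, predict_demo; it is an assumption about how such a default is typically resolved, not a copy of the policytree source.

```r
# Hypothetical stand-in showing how a vector default is usually resolved.
# This is NOT the policytree implementation.
predict_demo <- function(type = c("action.id", "node.id")) {
  # match.arg() returns the first choice when `type` is left at its default
  # and raises an error for values that match none of the declared choices.
  type <- match.arg(type)
  type
}

default.type <- predict_demo()            # falls back to the first choice
node.type <- predict_demo("node.id")      # explicit choice is kept
bad.type <- tryCatch(predict_demo("leaf"), error = function(e) "error")
```

Under this idiom, calling the function without type is equivalent to passing "action.id", and an unrecognized value fails early rather than silently returning the wrong kind of prediction.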

Value

A vector of predictions. For type = "action.id", each element is an integer from 1 to d, where d is the number of columns in the reward matrix. For type = "node.id", each element is the integer id of the leaf node the sample falls into, with nodes numbered in level order.
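
To make "level-ordered" concrete, the sketch below numbers the nodes of a toy depth-two tree breadth first. The nested-list tree and the helpers leaf, split_node and level_order_ids are hypothetical and are not the internal representation used by policytree; the point is only that, under level-order numbering, a full depth-two tree has its root at id 1, the two internal children at ids 2 and 3, and its four leaves at ids 4 to 7.

```r
# Toy tree built from nested lists; a hypothetical structure for illustration.
leaf <- function(action) list(action = action)
split_node <- function(left, right) list(left = left, right = right)
toy.tree <- split_node(split_node(leaf(1), leaf(2)),
                       split_node(leaf(2), leaf(1)))

# Breadth-first traversal: ids 1, 2, 3, ... are assigned in visiting order.
level_order_ids <- function(root) {
  queue <- list(root)
  id <- 0
  is.leaf <- logical(0)
  while (length(queue) > 0) {
    node <- queue[[1]]
    queue <- queue[-1]
    id <- id + 1
    # A node without a left child is treated as a leaf in this toy structure.
    is.leaf[id] <- is.null(node$left)
    if (!is.leaf[id]) {
      queue <- c(queue, list(node$left, node$right))
    }
  }
  data.frame(node.id = seq_len(id), is.leaf = is.leaf)
}

ids <- level_order_ids(toy.tree)
leaf.ids <- ids$node.id[ids$is.leaf]
```

If the fitted tree is not a full binary tree, the leaf ids will differ, so it is safer to read them from the printed tree than to assume a fixed range.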

Examples

# NOT RUN {
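# Attach policytree so its functions are available when run standalone
# (grf is called below through its namespace and must be installed).
library(policytree)

# Fit a depth two tree on doubly robust treatment effect estimates
# from a causal forest.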
n <- 10000
p <- 5
X <- round(matrix(rnorm(n * p), n, p), 2)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)
dr.scores <- double_robust_scores(c.forest)

tree <- policy_tree(X, dr.scores, 2)
tree

# Predict treatment assignment.
predicted <- predict(tree, X)

plot(X[, 1], X[, 2], col = predicted)
legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
abline(0, -1, lty = 2)

# Predict the leaf assigned to each sample.
node.id <- predict(tree, X, type = "node.id")
# Can be reshaped to a list of samples per leaf node with `split`.
samples.per.leaf <- split(1:n, node.id)

# The estimated value of each arm (with standard errors) within each leaf node.
values <- aggregate(dr.scores, by = list(leaf.node = node.id),
                    FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
print(values, digits = 2)
# }
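
Because action ids are integers from 1 to d, they can be used as column indices into the reward matrix. The sketch below uses a small made-up reward matrix and made-up action ids (both hypothetical, chosen only so the arithmetic is easy to follow) to average the reward of the chosen action per sample. In the example above, the same indexing could presumably be applied with dr.scores as the reward matrix and predicted as the action ids.

```r
# Hypothetical reward matrix with d = 2 columns, one per action.
rewards <- matrix(c(0.1, 0.4,
                    0.3, 0.2,
                    0.5, 0.0,
                    0.2, 0.6),
                  ncol = 2, byrow = TRUE)

# Hypothetical predicted action ids, one per row of `rewards`.
action.id <- c(2L, 1L, 1L, 2L)

# A two-column index matrix (row, column) picks one entry per sample.
picked <- rewards[cbind(seq_len(nrow(rewards)), action.id)]

# Average reward under the assigned actions.
policy.value <- mean(picked)
```

Matrix indexing with cbind(row, column) avoids an explicit loop over samples and keeps the result a plain numeric vector.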