policytree (version 1.0.4)

policy_tree: Fit a policy with exact tree search

Description

Finds the tree of depth L (set by depth) that maximizes the sum of rewards, using exhaustive search. If the optimal action is the same in both the left and right leaf of a node, the node is pruned.
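To make the exhaustive search concrete, here is a minimal base-R sketch of the depth-one case. For every feature and every threshold between distinct sorted values, it sums each action's rewards in the left and right halves, takes the best action on each side, and keeps the best split. The function name depth_one_search is hypothetical. This is not policytree's own implementation, which handles arbitrary depth far more efficiently. It only illustrates the objective being maximized.

```r
# Hypothetical helper, not part of policytree: exhaustive search over
# depth-one trees for an N x p covariate matrix X and an N x d reward
# matrix Gamma. Samples with X[, feature] <= threshold go left.
depth_one_search <- function(X, Gamma) {
  best <- list(reward = -Inf)
  for (j in seq_len(ncol(X))) {
    ord <- order(X[, j])
    x <- X[ord, j]
    G <- Gamma[ord, , drop = FALSE]
    # Row i holds each action's summed reward over the first i sorted samples.
    left <- apply(G, 2, cumsum)
    total <- colSums(G)
    for (i in seq_len(nrow(G) - 1)) {
      # Only consider thresholds between distinct covariate values.
      if (x[i] == x[i + 1]) next
      left.sum <- left[i, ]
      right.sum <- total - left.sum
      # Best action in each leaf; the split's value is the sum of both.
      reward <- max(left.sum) + max(right.sum)
      if (reward > best$reward) {
        best <- list(reward = reward, feature = j, threshold = x[i],
                     left.action = which.max(left.sum),
                     right.action = which.max(right.sum))
      }
    }
  }
  best
}
```

A split never does worse than a single leaf, since max(left) + max(right) >= max(left + right). When left.action equals right.action the split gains nothing, which is the case the pruning rule above removes.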

Usage

policy_tree(X, Gamma, depth = 2, split.step = 1, min.node.size = 1)

Arguments

X

The covariates used. Dimension \(N \times p\), where \(N\) is the number of observations and \(p\) is the number of features.

Gamma

The rewards for each action. Dimension \(N \times d\), where \(d\) is the number of actions.

depth

The depth of the fitted tree. Default is 2.

split.step

An optional approximation parameter (a positive integer): the number of possible splits to consider when performing tree search. split.step = 1 (default) considers every possible split, while split.step = 10 considers splitting at every 10th sample and may yield a substantial speedup for dense features. Manually rounding or re-encoding continuous covariates with very high cardinality in a problem-specific manner allows finer-grained control of the accuracy/runtime tradeoff, and may in some cases be preferable to this option.
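The manual rounding approach can be sketched in base R as follows. Rounding to one decimal place is an arbitrary, illustrative choice, and distinct_values is a hypothetical helper. A feature with m distinct values admits at most m - 1 distinct split thresholds, so coarsening the covariates shrinks the search.

```r
# Coarsen continuous covariates by rounding before tree search.
set.seed(1)
n <- 1000
X <- matrix(rnorm(n * 3), n, 3)

# Hypothetical helper: number of distinct values in each column.
distinct_values <- function(X) apply(X, 2, function(col) length(unique(col)))

distinct_values(X)          # all (or nearly all) values are distinct
X.rounded <- round(X, 1)    # one decimal place: an illustrative precision
distinct_values(X.rounded)  # far fewer distinct values per column
```

The coarsened matrix would then be passed in place of X, e.g. policy_tree(X.rounded, Gamma, depth = 2), where Gamma stands for whatever reward matrix is in use.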

min.node.size

An integer indicating the smallest terminal node size permitted. Default is 1.

Value

A policy_tree object.

Details

The amortized runtime of the exact tree search is \(O(p^k n^k (\log n + d) + p n \log n)\), where \(p\) is the number of features, \(d\) the number of treatments, \(n\) the number of observations, and \(k \geq 1\) the tree depth.

For a depth two tree this is \(O(p^2 n^2 (\log n + d))\) (ignoring the last term, which is a global sort done at the beginning). The search therefore scales quadratically with the number of observations: if you double the number of observations, it will take at least four times as long.
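The "at least four times" claim can be checked with a little arithmetic. The sketch below evaluates only the leading term \(p^k n^k (\log n + d)\), with all constants dropped, so only ratios between its outputs are meaningful. search_cost is a hypothetical helper, not a policytree function, and p = 4, d = 3 are arbitrary values.

```r
# Hypothetical helper: leading cost term p^k * n^k * (log(n) + d) of the
# exact search, with constant factors dropped. Only ratios are meaningful.
search_cost <- function(n, p, d, k) p^k * n^k * (log(n) + d)

# Depth two tree, 4 features, 3 actions: double n from 1000 to 2000.
ratio <- search_cost(2000, p = 4, d = 3, k = 2) /
  search_cost(1000, p = 4, d = 3, k = 2)
ratio  # about 4.28: the n^2 factor gives 4, the log(n) term a little more
```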

For a depth three tree it is \(O(p^3 n^3 (\log n + d))\). If a depth two tree with 1000 observations, 4 features and 3 actions took around t seconds, you can expect the depth three tree to take approximately \(1000 \cdot 4\) times as long (\(\approx \frac{p^3 n^3}{p^2 n^2} = pn\)).
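The depth-three estimate follows from the same leading term. The \((\log n + d)\) factor cancels in the ratio, leaving \(pn\). A sketch using a hypothetical search_cost helper (constants dropped, so only the ratio is meaningful):

```r
# Hypothetical helper: leading cost term p^k * n^k * (log(n) + d).
search_cost <- function(n, p, d, k) p^k * n^k * (log(n) + d)

# Depth three versus depth two for n = 1000, p = 4, d = 3.
ratio <- search_cost(1000, p = 4, d = 3, k = 3) /
  search_cost(1000, p = 4, d = 3, k = 2)
ratio  # 4000, i.e. p * n
```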

The runtime above assumes continuous features. There are considerable time savings when the features are discrete. In the extreme case where all features are binary, the runtime will be practically linear in n.
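One way to see where the savings for discrete features come from is to count how many distinct thresholds each feature offers. The sketch below uses a hypothetical helper, candidate_splits, and simulated data. It does not reproduce policytree's internal bookkeeping.

```r
# Hypothetical helper: a column with m distinct values can be split in at
# most m - 1 places.
candidate_splits <- function(X) apply(X, 2, function(col) length(unique(col)) - 1)

set.seed(2)
n <- 500
X.continuous <- matrix(rnorm(n * 2), n, 2)
X.binary <- matrix(rbinom(n * 2, 1, 0.5), n, 2)

candidate_splits(X.continuous)  # n - 1 = 499 per column, barring ties
candidate_splits(X.binary)      # 1 per column
```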

The optional approximation parameter split.step emulates rounding the data. Experimenting with it is recommended as a way to reduce the runtime.

References

Sverdrup, Erik, Ayush Kanodia, Zhengyuan Zhou, Susan Athey, and Stefan Wager. "policytree: Policy learning via doubly robust empirical welfare maximization over trees." Journal of Open Source Software 5, no. 50 (2020): 2232.

Zhou, Zhengyuan, Susan Athey, and Stefan Wager. "Offline multi-action policy learning: Generalization and optimization." arXiv preprint arXiv:1810.04778 (2018).

Examples

# NOT RUN {
library(policytree)
# Fit a depth two tree on doubly robust treatment effect estimates
# from a causal forest.
n <- 10000
p <- 5
X <- round(matrix(rnorm(n * p), n, p), 2)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)
dr.scores <- double_robust_scores(c.forest)

tree <- policy_tree(X, dr.scores, 2)
tree

# Predict treatment assignment.
predicted <- predict(tree, X)

plot(X[, 1], X[, 2], col = predicted)
legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
abline(0, -1, lty = 2)

# Predict the leaf assigned to each sample.
node.id <- predict(tree, X, type = "node.id")
# Can be reshaped to a list of samples per leaf node with `split`.
samples.per.leaf <- split(1:n, node.id)

# The value of all arms (along with SEs) by each leaf node.
values <- aggregate(dr.scores, by = list(leaf.node = node.id),
                    FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
print(values, digits = 2)
# }
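Since the tree maximizes the sum of rewards, a natural follow-up is the average reward of the assigned actions. With the example's objects this is mean(dr.scores[cbind(1:n, predicted)]). The self-contained sketch below shows the same matrix-indexing step on a small made-up reward matrix. Evaluating on the same scores that were used to fit the tree gives an optimistic, in-sample estimate.

```r
# Made-up N x d reward matrix (3 samples, 2 actions) and assigned actions,
# given as 1-based column indices like the action IDs from predict().
Gamma <- rbind(c(0.2,  0.5),
               c(0.1, -0.3),
               c(0.4,  0.9))
action <- c(2, 1, 2)

# Index with (row, column) pairs to pick each sample's reward under its
# assigned action.
chosen <- Gamma[cbind(seq_len(nrow(Gamma)), action)]
chosen        # 0.5 0.1 0.9
mean(chosen)  # 0.5, the average reward of this assignment
```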
