StratifiedMedicine (version 0.1.3)

ple_causal_forest: Patient-level Estimates: Causal Forest

Description

Uses the causal forest algorithm (Athey, Tibshirani, and Wager 2019; grf R package) to obtain patient-level estimates. Used for continuous or binary outcomes.

Usage

ple_causal_forest(Y, A, X, Xtest, tune = FALSE, num.trees = 500,
  family = "gaussian", mod.A = "mean", ...)

Arguments

Y

The outcome variable. Must be numeric or survival (e.g., Surv(time, cens)).

A

Treatment variable (a = 1, ..., A).

X

Covariate space.

Xtest

Test set of covariates.

tune

If TRUE, use grf automatic hyper-parameter tuning. If FALSE (default), no tuning.

num.trees

Number of trees (default=500)

family

Outcome type ("gaussian", "binomial"), default is "gaussian"

mod.A

Model for estimating P(A|X). The default, "mean", uses the sample mean of A. If mod.A="RF", P(A|X) is estimated using regression_forest (applicable for non-RCTs).

...

Any additional parameters, not currently passed through.
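
The mod.A and tune arguments can be illustrated by calling grf directly. The following is a minimal sketch on simulated data, not the package's internal code: it assumes a recent grf release (1.0 or later) in which causal_forest accepts W.hat and tune.parameters, and the variable names W.hat_mean, W.hat_rf, and ps_fit are hypothetical.

```r
library(grf)

set.seed(1)
n <- 200
p <- 5
X <- matrix(rnorm(n * p), n, p)            # simulated covariates
A <- rbinom(n, 1, 0.5)                     # simulated binary treatment
Y <- X[, 1] + A * (X[, 2] > 0) + rnorm(n)  # simulated continuous outcome

# Analogue of mod.A = "mean": a constant propensity equal to the sample mean of A
W.hat_mean <- rep(mean(A), n)

# Analogue of mod.A = "RF": propensity estimated by a regression forest,
# using out-of-bag predictions on the training data
ps_fit <- regression_forest(X, A, num.trees = 500)
W.hat_rf <- predict(ps_fit)$predictions

# Causal forest without tuning (analogue of tune = FALSE).
# Assumption: tune = TRUE would correspond to tune.parameters = "all".
cf <- causal_forest(X, Y, A, W.hat = W.hat_mean,
                    num.trees = 500, tune.parameters = "none")
```

In a randomized trial the constant propensity is usually adequate; with observational data, passing W.hat_rf instead of W.hat_mean is the forest-based alternative.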

Value

Trained causal_forest and regression_forest models.

  • mods - trained model(s)
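
For orientation, a trained grf causal forest can be queried for patient-level treatment-effect estimates with predict. The sketch below uses simulated data and calls grf directly; it does not show how ple_causal_forest stores its results, and the names tau_train and tau_test are hypothetical.

```r
library(grf)

set.seed(2)
n <- 200
X <- matrix(rnorm(n * 4), n, 4)        # simulated covariates
A <- rbinom(n, 1, 0.5)                 # simulated binary treatment
Y <- X[, 1] + A * X[, 2] + rnorm(n)    # effect of A varies with X[, 2]

cf <- causal_forest(X, Y, A, num.trees = 500)

# Out-of-bag treatment-effect estimates for the training patients
tau_train <- predict(cf)$predictions

# Estimates for new patients (a simulated test set)
Xtest <- matrix(rnorm(50 * 4), 50, 4)
tau_test <- predict(cf, newdata = Xtest)$predictions

summary(tau_train)
```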

Examples

library(StratifiedMedicine)
library(grf)

## Continuous ##
# Simulate a data set with a continuous outcome
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X
A = dat_ctns$A

# Fit the causal forest, using the training covariates as the test set
mod1 = ple_causal_forest(Y, A, X, Xtest=X)
summary(mod1$mu_train)