DeepLearningCausal (version 0.0.107)

metalearner_ensemble: metalearner_ensemble

Description

metalearner_ensemble implements the S-learner, T-learner, and X-learner for estimating conditional average treatment effects (CATEs) with a weighted ensemble of learners (super learner). The super learner can combine the following machine learning algorithms: extreme gradient boosting, glmnet (elastic net regression), random forest, and neural nets.
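To make the S-learner and T-learner concrete, the following is a minimal base-R sketch on simulated data. It is not the package's implementation: `lm` stands in for the super learner, and the data frame `d`, its columns `x`, `w`, `y`, and the true effect `1 + x` are all hypothetical choices made for illustration.

```r
# Illustrative sketch only: lm() replaces the super learner ensemble,
# and all data below are simulated (hypothetical names throughout).
set.seed(42)
n <- 1000
d <- data.frame(x = rnorm(n), w = rbinom(n, 1, 0.5))
d$y <- 0.5 * d$x + d$w * (1 + d$x) + rnorm(n)  # true CATE is 1 + x

# S-learner: one outcome model that includes the treatment as a predictor.
# The interaction term lets the treatment effect vary with x.
s_fit <- lm(y ~ x * w, data = d)
cate_s <- predict(s_fit, newdata = transform(d, w = 1)) -
  predict(s_fit, newdata = transform(d, w = 0))

# T-learner: separate outcome models for treated and control units.
t_fit1 <- lm(y ~ x, data = d[d$w == 1, ])
t_fit0 <- lm(y ~ x, data = d[d$w == 0, ])
cate_t <- predict(t_fit1, newdata = d) - predict(t_fit0, newdata = d)

# Compare the average estimated effect with the simulated truth.
c(s_learner = mean(cate_s), t_learner = mean(cate_t), truth = mean(1 + d$x))
```

With a fully interacted linear model the two estimators coincide. They generally differ once flexible learners such as boosting or neural nets are used.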

Usage

metalearner_ensemble(
  data = NULL,
  train.data = NULL,
  test.data = NULL,
  cov.formula,
  treat.var,
  meta.learner.type,
  SL.learners = c("SL.glmnet", "SL.xgboost", "SL.nnet"),
  nfolds = 5,
  family = gaussian(),
  binary.preds = FALSE,
  conformal = FALSE,
  alpha = 0.1,
  calib_frac = 0.5,
  seed = 1234
)

Value

metalearner_ensemble object containing the predicted outcome values and the CATEs estimated by the meta-learners for each observation.

Arguments

data

data.frame object of data for cross-validation

train.data

data.frame object of training data, used to train the meta-learners separately from the test data.

test.data

data.frame object of test data, on which the CATEs are estimated.

cov.formula

formula describing the model, of the form y ~ x (list of covariates), which lets users specify the outcome variable and the confounders.

treat.var

string for the name of treatment variable.

meta.learner.type

string specifying the meta-learner: "S.Learner" for the S-learner model, "T.Learner" for the T-learner model, "X.Learner" for the X-learner model, and "R.Learner" for the R-learner model.

SL.learners

character vector of learners for the super learner ensemble, for example "SL.glmnet", "SL.xgboost", and "SL.nnet". Supported algorithms include extreme gradient boosting, glmnet, random forest, and neural nets.

nfolds

number of folds for cross-validation. Currently supports up to 5 folds.

family

gaussian() or binomial() family for outcome variable.

binary.preds

logical for whether outcome predictions should be binary

conformal

logical for whether to compute conformal prediction intervals

alpha

proportion for conformal prediction intervals

calib_frac

fraction of training data to use for calibration in conformal inference

seed

random seed
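The conformal, alpha, and calib_frac arguments suggest a split conformal procedure, in which part of the training data is held out for calibration. The base-R sketch below shows that general technique on simulated data. It rests on assumptions not confirmed by this page: that alpha is the miscoverage level (so 0.1 targets 90% intervals) and that absolute residuals serve as conformity scores. `lm` again stands in for the ensemble, and all object names are hypothetical.

```r
# Illustrative split conformal sketch; not the package's internal code.
# Assumption: alpha is the miscoverage level, so intervals target 1 - alpha.
set.seed(1234)
alpha <- 0.1
calib_frac <- 0.5

n <- 1000
d <- data.frame(x = runif(n))
d$y <- 2 * d$x + rnorm(n)

# Hold out a fraction of the training data for calibration.
calib_idx <- sample(n, size = floor(calib_frac * n))
fit <- lm(y ~ x, data = d[-calib_idx, ])

# Conformity scores: absolute residuals on the calibration set.
calib <- d[calib_idx, ]
scores <- abs(calib$y - predict(fit, newdata = calib))

# Finite-sample corrected quantile of the scores.
n_cal <- length(scores)
k <- ceiling((1 - alpha) * (n_cal + 1))
q_hat <- if (k > n_cal) Inf else sort(scores)[k]

# Prediction intervals on fresh simulated data and their empirical coverage.
new <- data.frame(x = runif(2000))
new$y <- 2 * new$x + rnorm(2000)
pred <- predict(fit, newdata = new)
lower <- pred - q_hat
upper <- pred + q_hat
coverage <- mean(new$y >= lower & new$y <= upper)
coverage
```

A larger calib_frac gives a more stable quantile estimate but leaves fewer observations for fitting the model, which is the trade-off this argument controls in a split conformal setup.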

Examples

# load the package and the example dataset
library(DeepLearningCausal)
data(exp_data)
# load the SuperLearner package
library(SuperLearner)
# estimate CATEs with the S-learner
set.seed(123456)
slearner <- metalearner_ensemble(cov.formula = support_war ~ age +
                                   income + employed + job_loss,
                                 data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "S.Learner",
                                 SL.learners = c("SL.glm"),
                                 nfolds = 5,
                                 binary.preds = FALSE)
print(slearner)

# estimate CATEs with the T-learner
set.seed(123456)
tlearner <- metalearner_ensemble(cov.formula = support_war ~ age + income +
                                   employed + job_loss,
                                 data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "T.Learner",
                                 SL.learners = c("SL.xgboost",
                                                 "SL.nnet"),
                                 nfolds = 5,
                                 binary.preds = FALSE)

print(tlearner)

# estimate CATEs with the X-learner, using separate train and test data
set.seed(123456)
xlearner <- metalearner_ensemble(cov.formula = support_war ~ age + income +
                                   employed + job_loss,
                                 test.data = exp_data,
                                 train.data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "X.Learner",
                                 SL.learners = c("SL.glmnet", "SL.xgboost",
                                                 "SL.nnet"),
                                 binary.preds = TRUE)

print(xlearner)
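The X-learner example above runs several modeling stages inside a single call. As a rough illustration of the general X-learner recipe (not the package's code), the base-R sketch below spells those stages out on simulated data. `lm` and `glm` stand in for the super learner, and the data frame `d`, its columns, and the true effect `1 + 2 * x` are hypothetical.

```r
# Illustrative X-learner sketch on simulated data (hypothetical names).
set.seed(1)
n <- 2000
x <- runif(n, -1, 1)
w <- rbinom(n, 1, 0.5)
tau <- 1 + 2 * x                      # true CATE
y <- x + w * tau + rnorm(n, sd = 0.5)
d <- data.frame(y, x, w)

treated <- d[d$w == 1, ]
control <- d[d$w == 0, ]

# Stage 1: outcome models fitted separately in each arm.
mu1 <- lm(y ~ x, data = treated)
mu0 <- lm(y ~ x, data = control)

# Stage 2: imputed individual effects, using the other arm's model.
treated$d1 <- treated$y - predict(mu0, newdata = treated)
control$d0 <- predict(mu1, newdata = control) - control$y

# Stage 3: regress the imputed effects on the covariates.
tau1 <- lm(d1 ~ x, data = treated)
tau0 <- lm(d0 ~ x, data = control)

# Stage 4: combine both effect models, weighting by the propensity score.
ps <- glm(w ~ x, family = binomial(), data = d)
g <- predict(ps, newdata = d, type = "response")
cate_x <- g * predict(tau0, newdata = d) +
  (1 - g) * predict(tau1, newdata = d)

# Compare the average estimate with the simulated truth.
c(estimate = mean(cate_x), truth = mean(tau))
```

Weighting by the propensity score leans on the effect model that was fitted in the arm with more data, which is why the X-learner is often suggested when treatment groups are unbalanced.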

