
DeepLearningCausal (version 0.0.107)

metalearner_neural: metalearner_neural

Description

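metalearner_neural implements meta-learners for estimating the conditional average treatment effect (CATE) with deep neural networks: the S-learner, T-learner, X-learner, and R-learner, selected via meta.learner.type. By default, the networks are trained with the resilient backpropagation (Rprop) algorithm.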
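Rprop keeps a separate step size for each weight and uses only the sign of the gradient, not its magnitude. The rprop+ variant, the default for algorithm, also undoes the previous update whenever a gradient changes sign (weight backtracking). The base R sketch below applies that rule to a toy quadratic rather than to a network. The function name rprop_plus and its constants (growth factor 1.2, shrink factor 0.5, initial step 0.1, step bounds) are assumptions for illustration, not values taken from DeepLearningCausal.

```r
# Hypothetical sketch of Rprop+ (resilient backpropagation with weight
# backtracking). Constants are assumed, commonly cited defaults.
rprop_plus <- function(grad_fn, w, n_iter = 200,
                       eta_plus = 1.2, eta_minus = 0.5,
                       step_init = 0.1, step_min = 1e-6, step_max = 50) {
  step <- rep(step_init, length(w))  # one adaptive step size per weight
  prev_grad <- rep(0, length(w))
  prev_dw <- rep(0, length(w))
  for (i in seq_len(n_iter)) {
    g <- grad_fn(w)
    s <- prev_grad * g
    dw <- numeric(length(w))
    up <- s > 0
    down <- s < 0
    same <- s == 0
    # Gradient kept its sign: enlarge the step and move against the gradient
    step[up] <- pmin(step[up] * eta_plus, step_max)
    dw[up] <- -sign(g[up]) * step[up]
    # Gradient flipped sign: shrink the step, revert the previous update,
    # and zero the stored gradient so the next iteration takes a plain step
    step[down] <- pmax(step[down] * eta_minus, step_min)
    dw[down] <- -prev_dw[down]
    g[down] <- 0
    # No sign history (first iteration or right after a flip): plain step
    dw[same] <- -sign(g[same]) * step[same]
    w <- w + dw
    prev_grad <- g
    prev_dw <- dw
  }
  w
}

# Toy objective f(w) = sum((w - target)^2) with a known minimum
target <- c(3, -2)
grad_fn <- function(w) 2 * (w - target)
w_hat <- rprop_plus(grad_fn, w = c(0, 0))
w_hat  # should be close to c(3, -2)
```

In a network, grad_fn would instead return the gradient of the error function (err.fct) with respect to every weight, which backpropagation supplies.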

Usage

metalearner_neural(
  data,
  cov.formula,
  treat.var,
  meta.learner.type,
  stepmax = 1e+05,
  nfolds = 5,
  algorithm = "rprop+",
  hidden.layer = c(4, 2),
  act.fct = "logistic",
  err.fct = "sse",
  linear.output = TRUE,
  binary.preds = FALSE
)

Value

A metalearner_neural object containing the predicted outcome values and the CATE estimated by the meta-learner for each observation.
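Each observation's CATE is the difference between its predicted outcome under treatment and its predicted outcome under control. The sketch below shows how an S-learner and a T-learner produce those per-observation quantities, using base R on simulated data with a known CATE of 1 + 2x. It swaps in lm() for the neural networks that metalearner_neural fits, so it illustrates the meta-learner logic only; names such as s_cate and t_cate are hypothetical and are not fields of the returned object.

```r
# Sketch of S- and T-learner logic. Assumption: lm() replaces the neural
# network base learner; data are simulated with true CATE = 1 + 2 * x.
set.seed(1)
n <- 500
sim <- data.frame(x = runif(n), treat = rbinom(n, 1, 0.5))
sim$y <- 0.5 * sim$x + sim$treat * (1 + 2 * sim$x) + rnorm(n, sd = 0.1)

# S-learner: one model with the treatment as an input; the x:treat
# interaction lets the estimated effect vary with x
s_fit <- lm(y ~ x * treat, data = sim)
s_y1 <- predict(s_fit, newdata = transform(sim, treat = 1))  # outcome if treated
s_y0 <- predict(s_fit, newdata = transform(sim, treat = 0))  # outcome if untreated
s_cate <- s_y1 - s_y0

# T-learner: separate outcome models fitted on treated and control units
t_fit1 <- lm(y ~ x, data = sim[sim$treat == 1, ])
t_fit0 <- lm(y ~ x, data = sim[sim$treat == 0, ])
t_y1 <- predict(t_fit1, newdata = sim)
t_y0 <- predict(t_fit0, newdata = sim)
t_cate <- t_y1 - t_y0

# One row per observation: estimated CATEs next to the simulated truth
head(data.frame(s_cate, t_cate, true_cate = 1 + 2 * sim$x))
```

Because lm() cannot discover interactions on its own, the S-learner formula spells out x * treat; a neural network can in principle learn that interaction from the covariates and the treatment directly.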

Arguments

data

data.frame containing the outcome, treatment, and covariate variables.

cov.formula

formula describing the model, of the form y ~ x1 + x2 + ..., with the outcome on the left-hand side and the covariates on the right-hand side.

treat.var

string giving the name of the treatment variable in data.

meta.learner.type

string specifying the meta-learner: "S.Learner" for the S-learner, "T.Learner" for the T-learner, "X.Learner" for the X-learner, or "R.Learner" for the R-learner model.

stepmax

maximum number of steps for training the neural network.

nfolds

number of folds for cross-validation. Currently supports up to 5 folds.

algorithm

string naming the algorithm used to train the neural network. Defaults to "rprop+", resilient backpropagation (Rprop) with weight backtracking.

hidden.layer

vector of integers giving the number of neurons in each hidden layer; for example, c(4, 2) specifies two hidden layers with 4 and 2 neurons, respectively.

act.fct

"logistic" or "tanh", the activation function used in the neural network.

err.fct

"ce" for cross-entropy or "sse" for the sum of squared errors, used as the error function.

linear.output

logical specifying a regression (TRUE) or classification (FALSE) model.

binary.preds

logical specifying whether the predicted outcome takes binary values (TRUE) or proportions (FALSE).
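To make hidden.layer, act.fct, linear.output, and err.fct concrete, the sketch below runs a forward pass through a small, untrained feed-forward network in base R and scores its output with both error functions. The helper names (init_weights, forward, activations, errors), the random weights, and the 1/2 factor in the sum of squared errors are assumptions for illustration; metalearner_neural's internals may differ.

```r
# Hypothetical sketch of what hidden.layer, act.fct, linear.output, and
# err.fct control. Weights are random (untrained); helper names are illustrative.
logistic <- function(z) 1 / (1 + exp(-z))
activations <- list(logistic = logistic, tanh = tanh)
errors <- list(
  sse = function(pred, y) 0.5 * sum((y - pred)^2),  # 1/2 factor is an assumed convention
  ce  = function(pred, y) -sum(y * log(pred) + (1 - y) * log(1 - pred))
)

init_weights <- function(n_inputs, hidden.layer, seed = 1) {
  set.seed(seed)
  sizes <- c(n_inputs, hidden.layer, 1)  # inputs, hidden layers, one output
  # One weight matrix per layer transition; the extra first row is the bias
  lapply(seq_len(length(sizes) - 1), function(l) {
    matrix(rnorm((sizes[l] + 1) * sizes[l + 1]), nrow = sizes[l] + 1)
  })
}

forward <- function(x, weights, act.fct = "logistic", linear.output = TRUE) {
  act <- activations[[act.fct]]
  a <- x
  for (l in seq_along(weights)) {
    z <- cbind(1, a) %*% weights[[l]]
    is_output <- l == length(weights)
    # linear.output = TRUE leaves the output layer without an activation
    a <- if (is_output && linear.output) z else act(z)
  }
  drop(a)
}

x <- matrix(c(0.2, 0.5, 0.9,
              0.7, 0.1, 0.4), nrow = 2, byrow = TRUE)  # 2 rows, 3 covariates
y <- c(1, 0)
w <- init_weights(n_inputs = 3, hidden.layer = c(4, 2))
pred <- forward(x, w, act.fct = "logistic", linear.output = FALSE)
errors$sse(pred, y)
errors$ce(pred, y)
```

Cross-entropy requires predictions strictly between 0 and 1, which is why it pairs naturally with linear.output = FALSE and a logistic output layer; with linear.output = TRUE the output is unbounded and "sse" is the natural choice.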

Examples

# \donttest{
# load dataset
data(exp_data)
# estimate CATEs with S Learner
set.seed(123456)
slearner_nn <- metalearner_neural(cov.formula = support_war ~ age + income +
                                   employed + job_loss,
                                   data = exp_data,
                                   treat.var = "strong_leader",
                                   meta.learner.type = "S.Learner",
                                   stepmax = 2e+9,
                                   nfolds = 5,
                                   algorithm = "rprop+",
                                   hidden.layer = c(1),
                                   linear.output = FALSE,
                                   binary.preds = FALSE)

print(slearner_nn)
# }
# \donttest{
# load dataset
data(exp_data)
set.seed(123456)
# estimate CATEs with T Learner
tlearner_nn <- metalearner_neural(cov.formula = support_war ~ age + income +
                                  employed + job_loss,
                                  data = exp_data,
                                  treat.var = "strong_leader",
                                  meta.learner.type = "T.Learner",
                                  stepmax = 1e+9,
                                  nfolds = 5,
                                  algorithm = "rprop+",
                                  hidden.layer = c(2,1),
                                  linear.output = FALSE,
                                  binary.preds = FALSE)

print(tlearner_nn)
# }

# \donttest{
# load dataset
data(exp_data)
set.seed(123456)
# estimate CATEs with X Learner
xlearner_nn <- metalearner_neural(cov.formula = support_war ~ age + income +
                                  employed + job_loss,
                                  data = exp_data,
                                  treat.var = "strong_leader",
                                  meta.learner.type = "X.Learner",
                                  stepmax = 2e+9,
                                  nfolds = 5,
                                  algorithm = "rprop+",
                                  hidden.layer = c(3),
                                  linear.output = FALSE,
                                  binary.preds = FALSE)

print(xlearner_nn)
# }
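The X.Learner example above fits neural networks at every stage. The staging itself can be sketched in base R: fit an outcome model per treatment arm, impute individual treatment effects using the opposite arm's model, model those imputed effects, and blend the two effect models with an estimated propensity score. In this sketch lm() and glm() stand in for the neural networks, the data are simulated with a known CATE of 1 + 2x, and the propensity weighting follows the standard X-learner recipe rather than anything confirmed about metalearner_neural's internals.

```r
# X-learner sketch. Assumptions: lm()/glm() replace the neural networks;
# data are simulated with true CATE = 1 + 2 * x.
set.seed(2)
n <- 500
sim <- data.frame(x = runif(n), treat = rbinom(n, 1, 0.3))
sim$y <- 0.5 * sim$x + sim$treat * (1 + 2 * sim$x) + rnorm(n, sd = 0.1)
treated <- sim[sim$treat == 1, ]
control <- sim[sim$treat == 0, ]

# Stage 1: an outcome model for each treatment arm
mu1 <- lm(y ~ x, data = treated)
mu0 <- lm(y ~ x, data = control)

# Stage 2: impute individual effects with the other arm's model,
# then regress the imputed effects on the covariates within each arm
treated$d <- treated$y - predict(mu0, newdata = treated)
control$d <- predict(mu1, newdata = control) - control$y
tau1 <- lm(d ~ x, data = treated)
tau0 <- lm(d ~ x, data = control)

# Stage 3: blend the two effect models using the propensity score e(x)
e <- fitted(glm(treat ~ x, family = binomial, data = sim))
x_cate <- e * predict(tau0, newdata = sim) + (1 - e) * predict(tau1, newdata = sim)
summary(x_cate)
```

Weighting tau0 by e(x) leans on the control-arm effect model where treated units are common, which is where its imputed effects (built from the treated-arm outcome model) are most reliable.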
