Wrapper function to train a subgroup model (submod). It is called internally by PRISM and can also be called directly to fit a submod model by name.
submod_train(Y, A, X, Xtest, mu_train = NULL, family = "gaussian",
submod, hyper = NULL, ...)
Y - The outcome variable. Must be numeric or survival (e.g., Surv(time, cens)).
A - Treatment variable (a = 1, ..., A).
X - Covariate space.
Xtest - Test set.
mu_train - Patient-level estimates (see PLE_models). Default is NULL.
family - Outcome type ("gaussian", "binomial", "survival"). Default is "gaussian".
submod - Subgroup identification (submod) function. Maps the observed data and/or PLEs to subgroups.
hyper - Hyper-parameters for submod (must be a list). Default is NULL.
... - Any additional parameters, not currently passed through.
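As a rough illustration of what fitting "by name" with a list of hyper-parameters can look like, the base-R sketch below dispatches to a function through its name and appends `hyper` to the call. It is not the package's implementation: `toy_submod` and `train_by_name` are hypothetical names used only here, and the median split is a stand-in for a real subgroup model.

```r
# Hypothetical stand-in for a submod function: splits on the median of one covariate.
toy_submod <- function(Y, A, X, Xtest, mu_train = NULL, family = "gaussian",
                       var = 1, ...) {
  cut_pt <- median(X[[var]])
  assign_grp <- function(d) ifelse(d[[var]] <= cut_pt, "low", "high")
  list(mod = list(var = var, cut = cut_pt),
       Subgrps.train = assign_grp(X),
       Subgrps.test = assign_grp(Xtest))
}

# Hypothetical wrapper: looks up the function by name and appends `hyper` to the call.
train_by_name <- function(Y, A, X, Xtest, mu_train = NULL, family = "gaussian",
                          submod, hyper = NULL) {
  if (!is.null(hyper) && !is.list(hyper)) stop("hyper must be a list")
  args <- c(list(Y = Y, A = A, X = X, Xtest = Xtest,
                 mu_train = mu_train, family = family), hyper)
  do.call(submod, args)
}

# Simulated data, used only to exercise the sketch.
set.seed(1)
X <- data.frame(x1 = rnorm(20), x2 = rnorm(20))
Y <- rnorm(20)
A <- rbinom(20, 1, 0.5)

res <- train_by_name(Y, A, X, Xtest = X, submod = "toy_submod",
                     hyper = list(var = "x2"))
table(res$Subgrps.train)
```

Requiring `hyper` to be a list is what makes the `c(...)` concatenation safe: each element becomes a named argument of the function being called.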
A list containing the trained subgroup model and the subgroup predictions/estimates for the training and test sets:
fit - trained subgroup model
Subgrps.train - Identified subgroups (training set)
Subgrps.test - Identified subgroups (test set)
pred.train - Predictions (training set)
pred.test - Predictions (test set)
Rules - Definitions for subgroups, if provided in fitted submod output.
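The components listed here might be inspected as in the short sketch below. It reuses the calls from this page's own example and assumes the StratifiedMedicine version documented here, where `submod` is given as the string "submod_lmtree"; the object names `dat` and `res` are arbitrary.

```r
library(StratifiedMedicine)

# Simulated continuous-outcome data, as in the package example
dat <- generate_subgrp_data(family = "gaussian")

res <- submod_train(Y = dat$Y, A = dat$A, X = dat$X, Xtest = dat$X,
                    submod = "submod_lmtree")

names(res)                 # components of the returned list
table(res$Subgrps.train)   # subgroup sizes in the training set
head(res$pred.test)        # first few test-set predictions
res$Rules                  # subgroup definitions; NULL if the submod provides none
```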
# NOT RUN {
library(StratifiedMedicine)
## Continuous ##
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X
A = dat_ctns$A
# Fit submod_lmtree directly #
mod1 = submod_lmtree(Y, A, X, Xtest=X)
plot(mod1$mod)
# Fit through submod_train wrapper #
mod2 = submod_train(Y=Y, A=A, X=X, Xtest=X, submod="submod_lmtree")
plot(mod2$fit$mod)
# }