rule_fit() is a way to generate a specification of a model before fitting. The main arguments for the model are:

mtry: The number of predictors that will be randomly sampled at each split when creating the tree models.

trees: The number of trees contained in the ensemble.

min_n: The minimum number of data points in a node that are required for the node to be split further.

tree_depth: The maximum depth of the tree (i.e., the number of splits).

learn_rate: The rate at which the boosting algorithm adapts from iteration to iteration.

loss_reduction: The reduction in the loss function required to split further.

sample_size: The amount of data exposed to the fitting routine.

penalty: The amount of regularization in the glmnet model.

These arguments are converted to their specific names at the time that the model is fit. Other options and arguments can be set using parsnip::set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.
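For example, a specification can be created, paired with the engine, and then passed to parsnip::translate() to see how the main arguments map onto the engine-specific names. This is a minimal sketch; the argument values below are arbitrary illustrations:

library(rules)
library(parsnip)

# Create a specification with a few main arguments set
rf_spec <-
  rule_fit(trees = 20, tree_depth = 4, penalty = 0.01) %>%
  set_mode("regression") %>%
  set_engine("xrf")

# Show how the main arguments are translated to the engine's own names
translate(rf_spec)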
rule_fit(
mode = "unknown",
mtry = NULL,
trees = NULL,
min_n = NULL,
tree_depth = NULL,
learn_rate = NULL,
loss_reduction = NULL,
sample_size = NULL,
penalty = NULL
)

# S3 method for rule_fit
update(
object,
parameters = NULL,
mtry = NULL,
trees = NULL,
min_n = NULL,
tree_depth = NULL,
learn_rate = NULL,
loss_reduction = NULL,
sample_size = NULL,
penalty = NULL,
fresh = FALSE,
...
)
mode: A single character string for the type of model. Possible values for this model are "unknown", "regression", or "classification".

mtry: A number for the number (or proportion) of predictors that will be randomly sampled at each split when creating the tree models.

trees: An integer for the number of trees contained in the ensemble.

min_n: An integer for the minimum number of data points in a node that are required for the node to be split further.

tree_depth: An integer for the maximum depth of the tree (i.e., the number of splits).

learn_rate: A number for the rate at which the boosting algorithm adapts from iteration to iteration.

loss_reduction: A number for the reduction in the loss function required to split further.

sample_size: A number for the number (or proportion) of data that is exposed to the fitting routine.

penalty: L1 regularization parameter.

object: A rule_fit model specification.

parameters: A 1-row tibble or named list with main parameters to update. If the individual arguments are used, these will supersede the values in parameters. Also, using engine arguments in this object will result in an error.

fresh: A logical for whether the arguments should be modified in-place or replaced wholesale.

...: Not used for update().

Value: An updated parsnip model specification.
The RuleFit model creates a regression model of rules in two stages. The first stage uses a tree-based model to generate a set of rules that can be filtered, modified, and simplified. These rules are then added as predictors to a regularized generalized linear model that can also conduct feature selection during model training.
For the xrf
engine, the xgboost
package is used to create the rule set
that is then added to a glmnet
model. The only available engine is "xrf"
.
Note that, per the documentation in ?xrf, transformations of the response variable are not supported. To use such transformations with rule_fit(), we recommend using a recipe instead of the formula method, as in the sketch below.
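A minimal sketch of that approach is shown here, using the built-in mtcars data purely for illustration; the log transformation, the recipe step, and the argument values are assumptions, not part of the package documentation:

library(rules)
library(recipes)
library(workflows)

# Transform the outcome in the data rather than in the fit() formula
car_data <- mtcars
car_data$mpg <- log(car_data$mpg)

# A recipe handles the predictors; the outcome is left as-is
car_recipe <-
  recipe(mpg ~ ., data = car_data) %>%
  step_zv(all_predictors())

rule_fit_wflow <-
  workflow() %>%
  add_recipe(car_recipe) %>%
  add_model(rule_fit(trees = 10) %>% set_mode("regression"))

rule_fit_fit <- fit(rule_fit_wflow, data = car_data)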
There are also several configuration differences in how xrf() is fit between that package and the wrapper used in rules. Some differences in default values are:

trees (xrf: 100, rules: 15)

max_depth (xrf: 3, rules: 6)

These differences will change the values of the penalty argument that glmnet uses. In addition, rules can set penalty directly, whereas xrf uses an internal 5-fold cross-validation to determine it (by default).
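As a rough illustration, a specification that mirrors xrf's own defaults and fixes the regularization amount (rather than relying on the internal cross-validation) might look like the following; the penalty value is an arbitrary example, not a recommended setting:

library(rules)
library(parsnip)

# Mirror the xrf() defaults for the ensemble and set the glmnet penalty
xrf_like_spec <-
  rule_fit(trees = 100, tree_depth = 3, penalty = 0.01) %>%
  set_mode("regression") %>%
  set_engine("xrf")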
Friedman, J. H., and Popescu, B. E. (2008). "Predictive learning via rule ensembles." The Annals of Applied Statistics, 2(3), 916-954.
rule_fit()
# Main arguments can be set directly in the specification:
rule_fit(trees = 7)
# ------------------------------------------------------------------------------
set.seed(6907)
rule_fit_rules <-
  rule_fit(trees = 3, penalty = 0.1) %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)
# ------------------------------------------------------------------------------
model <- rule_fit(trees = 10, min_n = 2)
model
update(model, trees = 1)
update(model, trees = 1, fresh = TRUE)