ml_random_forest
Spark ML -- Random Forests
Perform regression or classification using random forests with a Spark DataFrame.
Usage
ml_random_forest(x, response, features, col.sample.rate = NULL,
impurity = c("auto", "gini", "entropy", "variance"), max.bins = 32L,
max.depth = 5L, min.info.gain = 0, min.rows = 1L, num.trees = 20L,
sample.rate = 1, thresholds = NULL, seed = NULL, type = c("auto",
"regression", "classification"), checkpoint.interval = 10L,
cache.node.ids = FALSE, max.memory = 256L, ml.options = ml_options(),
...)
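A call with this signature might look like the following. This is a minimal sketch, not an official example: it assumes a local Spark installation reachable through spark_connect(master = "local"), and it uses the built-in iris data set, whose column names copy_to() rewrites with underscores. The table name, the chosen features, and the parameter values are illustrative only.

```r
library(sparklyr)
library(dplyr)

# Assumption: Spark is installed locally (for example via spark_install()).
sc <- spark_connect(master = "local")

# copy_to() replaces the dots in the iris column names with underscores,
# so Petal.Length becomes Petal_Length.
iris_tbl <- copy_to(sc, iris, "iris", overwrite = TRUE)

# Classification with an explicit response name and feature vector.
fit <- ml_random_forest(
  iris_tbl,
  response = "Species",
  features = c("Petal_Length", "Petal_Width"),
  num.trees = 50L,
  max.depth = 4L,
  type = "classification",
  seed = 42L
)

# The same response and features supplied as a formula; all other
# arguments keep their defaults.
fit_formula <- ml_random_forest(
  iris_tbl,
  Species ~ Petal_Length + Petal_Width,
  type = "classification"
)

print(fit)

spark_disconnect(sc)
```

Passing type explicitly avoids relying on the "auto" inference from the response column type.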
Arguments
- x
An object coercible to a Spark DataFrame (typically, a tbl_spark).
- response
The name of the response vector (as a length-one character vector), or a formula giving a symbolic description of the model to be fitted. When response is a formula, it is used in preference to other parameters to set the response, features, and intercept parameters (if available). Currently, only simple linear combinations of existing parameters are supported; e.g. response ~ feature1 + feature2 + ... . The intercept term can be omitted by using - 1 in the model fit.
- features
The names of the features (terms) to use for the model fit.
- col.sample.rate
The sampling rate of features to consider for splits at each tree node. Defaults to 1/3 for regression and sqrt(k)/k for classification, where k is the number of features. For Spark versions prior to 2.0.0, arbitrary sampling rates are not supported, so the input is automatically mapped to one of "onethird", "sqrt", or "log2".
- impurity
Criterion used for information gain calculation. One of 'auto', 'gini', 'entropy', or 'variance'. 'auto' defaults to 'gini' for classification and 'variance' for regression.
- max.bins
The maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity.
- max.depth
Maximum depth of the tree (>= 0); that is, the maximum number of nodes separating any leaves from the root of the tree.
- min.info.gain
Minimum information gain for a split to be considered at a tree node. Should be >= 0, defaults to 0.
- min.rows
Minimum number of instances each child must have after split.
- num.trees
Number of trees to train (>= 1), defaults to 20.
- sample.rate
Fraction of the training data used for learning each decision tree, defaults to 1.0.
- thresholds
Thresholds in multi-class classification to adjust the probability of predicting each class. The vector must have length equal to the number of classes, with values > 0, except that at most one value may be 0. The class with the largest value of p/t is predicted, where p is the original probability of that class and t is the class's threshold.
- seed
Seed for random numbers.
- type
The type of model to fit. "regression" treats the response as a continuous variable, while "classification" treats the response as a categorical variable. When "auto" is used, the model type is inferred based on the response variable type -- if it is a numeric type, then regression is used; classification otherwise.
- checkpoint.interval
Set checkpoint interval (>= 1) or disable checkpoint (-1). For example, 10 means that the cache will get checkpointed every 10 iterations. Defaults to 10.
- cache.node.ids
If FALSE, the algorithm will pass trees to executors to match instances with nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Defaults to FALSE.
- max.memory
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. Defaults to 256.
- ml.options
Optional arguments, used to affect the model generated. See ml_options for more details.
- ...
Optional arguments. The data argument can be used to specify the data to be used when x is a formula; this allows calls of the form ml_linear_regression(y ~ x, data = tbl), and is especially useful in conjunction with do.
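Two of the argument descriptions above involve small calculations: the default col.sample.rate and the p/t rule used with thresholds. The sketch below restates both in plain R so the arithmetic can be checked without a Spark connection. Both helper functions are hypothetical and written only for illustration; they are not part of sparklyr or Spark, and they mirror the wording of the descriptions rather than the actual implementation.

```r
# Hypothetical helper: default feature sampling rate as described for
# col.sample.rate (1/3 for regression, sqrt(k)/k for classification).
default_col_sample_rate <- function(k, type = c("regression", "classification")) {
  type <- match.arg(type)
  if (type == "regression") 1 / 3 else sqrt(k) / k
}

# Hypothetical helper: pick the class with the largest p/t, where p is the
# class probability and t is the threshold for that class.
predict_with_thresholds <- function(probs, thresholds) {
  stopifnot(length(probs) == length(thresholds))
  scaled <- probs / thresholds
  names(probs)[which.max(scaled)]
}

# With 16 features, classification samples sqrt(16)/16 = 0.25 of them.
rate <- default_col_sample_rate(16, "classification")

# Made-up class probabilities for a single observation.
probs <- c(a = 0.5, b = 0.3, c = 0.2)

# Equal thresholds leave the most probable class unchanged.
equal_pick <- predict_with_thresholds(probs, c(1, 1, 1))

# Lowering the threshold of class b favors it: 0.3 / 0.5 = 0.6 > 0.5 / 1.
shifted_pick <- predict_with_thresholds(probs, c(1, 0.5, 1))
```

A smaller threshold makes a class easier to predict, because its probability is divided by a smaller number before the comparison.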
See Also
Other Spark ML routines: ml_als_factorization, ml_decision_tree, ml_generalized_linear_regression, ml_gradient_boosted_trees, ml_kmeans, ml_lda, ml_linear_regression, ml_logistic_regression, ml_multilayer_perceptron, ml_naive_bayes, ml_one_vs_rest, ml_pca, ml_survival_regression