gbt.train
gbt.train is an interface for training an agtboost model.
gbt.train(
y,
x,
learning_rate = 0.01,
loss_function = "mse",
nrounds = 50000,
verbose = 0,
gsub_compare,
algorithm = "global_subset",
previous_pred = NULL,
weights = NULL,
force_continued_learning = FALSE,
offset = NULL,
...
)
y: response vector for training. Must correspond to the design matrix x.
x: design matrix for training. Must be of type matrix.
learning_rate: controls the learning rate by scaling the contribution of each tree by a factor 0 < learning_rate < 1 when it is added to the current approximation. A lower learning_rate requires more boosting iterations: the model is more robust to overfitting but slower to compute. Default: 0.01.
loss_function: specifies the learning objective (loss function). Only pre-specified loss functions are currently supported:
  mse: regression with squared error loss (default).
  logloss: logistic regression for binary classification; outputs the score before the logistic transformation.
  poisson: Poisson regression for count data using a log link; outputs the score before the exponential transformation.
  gamma::neginv: gamma regression using the canonical negative inverse link. Scaling is independent of y.
  gamma::log: gamma regression using the log link. Constant information parametrisation.
  negbinom: negative binomial regression for count data with overdispersion, using a log link.
  count::auto: chooses automatically between Poisson and negative binomial regression.
nrounds: a just-in-case maximum number of boosting iterations. Default: 50000.
verbose: enables boosting trace information at the i-th iteration. Default: 0.
gsub_compare: deprecated. Boolean: global-subset comparisons. FALSE means standard GTB; TRUE compares subset splits with global splits (the next root split). Default: TRUE.
algorithm: specifies the algorithm used for gradient tree boosting:
  vanilla: ordinary gradient tree boosting. Trees are optimized as if they were the last tree.
  global_subset (default): the function change targets the maximal reduction in generalization loss for individual data points.
previous_pred: prediction vector for training. Boosting starts from the predictions of another model.
weights: vector of weights scaling the contributions of individual observations. Default: NULL (the unit vector).
force_continued_learning: Boolean. FALSE (default) stops at the information stopping criterion; TRUE stops after nrounds iterations.
offset: adds an offset to the model, g(mu) = offset + F(x), where g is the link function and F(x) is the ensemble prediction.
...: additional parameters. If loss_function is "negbinom", dispersion must be provided in ...
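Several of the loss functions above return scores on the link scale, and offset enters additively on that same scale, g(mu) = offset + F(x). The base-R sketch below illustrates, under that reading, how a raw score could be mapped back to the response scale for each link named above. The helper inverse_link is hypothetical and not part of agtboost.

```r
# Hypothetical helper (not part of agtboost): map a link-scale score
# eta = offset + F(x) back to the response scale mu = g^{-1}(eta).
inverse_link <- function(score, loss_function, offset = 0) {
  eta <- offset + score
  switch(loss_function,
    "mse"           = eta,          # identity link
    "logloss"       = plogis(eta),  # logistic: 1 / (1 + exp(-eta))
    "poisson"       = exp(eta),     # log link
    "negbinom"      = exp(eta),     # log link
    "count::auto"   = exp(eta),     # Poisson or negative binomial, both log link
    "gamma::log"    = exp(eta),     # log link
    "gamma::neginv" = -1 / eta,     # canonical link eta = -1/mu, requires eta < 0
    stop("unsupported loss_function: ", loss_function)
  )
}

score <- c(-1, 0, 1)
inverse_link(score, "logloss")
inverse_link(score, "poisson", offset = log(2))  # e.g. a log-exposure offset
inverse_link(c(-0.5, -2), "gamma::neginv")
```

With a log link, an offset of log(exposure) multiplies the predicted mean by the exposure, which is the usual reason for supplying one for count data.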
An object of class ENSEMBLE with some or all of the following elements:
handle: a handle (pointer) to the agtboost model in memory.
initialPred: a field containing the initial prediction of the ensemble.
set_param: function for changing the parameters of the ensemble.
train: function for re-training (or training from scratch) the ensemble directly on vector y and design matrix x.
predict: function for predicting observations given a design matrix.
predict2: function as above, but with an additional parameter giving the maximum number of boosting iterations to use.
estimate_generalization_loss: function for calculating the (approximate) optimism of the ensemble.
get_num_trees: function returning the number of trees in the ensemble.
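One way to read predict2 is that the ensemble prediction (on the link scale) is the initial prediction plus the sum of the tree predictions, truncated after a given number of trees. The sketch below illustrates that reading with a hypothetical list of tree functions; it assumes each tree's output is already scaled by learning_rate and does not use agtboost's internal model representation. The function predict_truncated is hypothetical.

```r
# Hypothetical ensemble: an initial prediction plus a list of tree functions
# whose outputs are assumed to be already scaled by the learning rate.
predict_truncated <- function(initial_pred, trees, newx, max_trees = length(trees)) {
  k <- min(max_trees, length(trees))
  pred <- rep(initial_pred, nrow(newx))
  for (m in seq_len(k)) {
    pred <- pred + trees[[m]](newx)  # add only the first k trees
  }
  pred
}

# Two toy "trees" (stumps on the first column of the design matrix)
trees <- list(
  function(X) ifelse(X[, 1] <= 2, -0.5, 0.5),
  function(X) ifelse(X[, 1] <= 1, -0.2, 0.1)
)
newx <- matrix(c(0.5, 3), ncol = 1)
predict_truncated(1, trees, newx, max_trees = 1)  # 0.5 1.5
predict_truncated(1, trees, newx)                 # 0.3 1.6
```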
These are the training functions for an agtboost model.
gbt.train learns trees with adaptive complexity given by an information criterion, until the same (but scaled) information criterion tells the algorithm to stop. The data used for training at each boosting iteration stem from a second-order Taylor expansion of the loss function, evaluated at the predictions given by the ensemble at the previous boosting iteration.
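The second-order Taylor step can be sketched in base R. At each iteration the first and second derivatives g_i and h_i of the loss are computed at the current predictions; a tree is fit to them, each leaf receives the Newton weight w = -sum(g)/sum(h), and that weight is shrunk by learning_rate before being added. This is a minimal sketch under stated assumptions: it uses the logloss objective on the score scale, a single-split stump on one feature, and a fixed number of rounds. It does not implement agtboost's information criterion for tree complexity or stopping, and all function names are hypothetical.

```r
# Hypothetical sketch: Newton (second-order) boosting with stumps for logloss.
grad_hess_logloss <- function(y, score) {
  p <- plogis(score)
  list(g = p - y, h = p * (1 - p))  # first and second derivative w.r.t. score
}

# Choose the split that most reduces the second-order approximation of the loss.
# At the optimal leaf weight w = -G/H a leaf contributes -G^2 / (2H), so the
# gain below is twice that reduction (the factor does not change the argmax).
fit_stump <- function(x, g, h) {
  ord <- order(x)
  xs <- x[ord]; G <- cumsum(g[ord]); H <- cumsum(h[ord])
  n <- length(x); G_tot <- G[n]; H_tot <- H[n]
  best <- list(gain = -Inf)
  for (i in seq_len(n - 1)) {
    if (xs[i] == xs[i + 1]) next  # cannot split between tied values
    gain <- G[i]^2 / H[i] + (G_tot - G[i])^2 / (H_tot - H[i]) - G_tot^2 / H_tot
    if (gain > best$gain) {
      best <- list(gain = gain, split = (xs[i] + xs[i + 1]) / 2,
                   w_left = -G[i] / H[i],
                   w_right = -(G_tot - G[i]) / (H_tot - H[i]))
    }
  }
  best
}

predict_stump <- function(stump, x) ifelse(x <= stump$split, stump$w_left, stump$w_right)

boost_logloss <- function(y, x, learning_rate = 0.1, nrounds = 100) {
  score <- rep(qlogis(mean(y)), length(y))  # initial prediction (log-odds)
  for (m in seq_len(nrounds)) {
    gh <- grad_hess_logloss(y, score)       # expansion at previous predictions
    stump <- fit_stump(x, gh$g, gh$h)
    score <- score + learning_rate * predict_stump(stump, x)  # shrunken update
  }
  score
}

set.seed(1)
x <- runif(500, 0, 4)
y <- rbinom(500, 1, plogis(2 * (x - 2)))
score <- boost_logloss(y, x)
mean((plogis(score) > 0.5) == y)  # training accuracy
```

Lowering learning_rate in this sketch makes each update smaller, so more rounds are needed to reach the same training loss, which mirrors the trade-off described for the learning_rate argument.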
Berent Ånund Strømnes Lunde, Tore Selland Kleppe and Hans Julius Skaug, "An Information Criterion for Automatic Gradient Tree Boosting", 2020, https://arxiv.org/abs/2008.05926
## A simple gbt.train example with linear regression:
library(agtboost)
x <- runif(500, 0, 4)
y <- rnorm(500, x, 1)
x.test <- runif(500, 0, 4)
y.test <- rnorm(500, x.test, 1)
mod <- gbt.train(y, as.matrix(x))  # default loss_function = "mse"
y.pred <- predict(mod, as.matrix(x.test))
plot(x.test, y.test)
points(x.test, y.pred, col = "red")  # model predictions in red