Compute predictions from party objects.
# S3 method for party
predict(object, newdata = NULL, perm = NULL, ...)
predict_party(party, id, newdata = NULL, ...)
# S3 method for default
predict_party(party, id, newdata = NULL, FUN = NULL, ...)
# S3 method for constparty
predict_party(party, id, newdata = NULL,
  type = c("response", "prob", "quantile", "density", "node"),
  at = if (type == "quantile") c(0.1, 0.5, 0.9),
  FUN = NULL, simplify = TRUE, ...)
# S3 method for simpleparty
predict_party(party, id, newdata = NULL,
  type = c("response", "prob", "node"), ...)
object: an object of class party.
newdata: an optional data frame in which to look for variables with which to predict; if omitted, the fitted values are used.
perm: an optional character vector of variable names. Splits of nodes with a primary split in any of these variables will be permuted (after dealing with surrogates). Note that surrogate splits in the perm variables will not be permuted.
party: an object of class party.
id: a vector of terminal node identifiers.
type: a character string denoting the type of predicted value returned; ignored when argument FUN is given. For "response", the mean of a numeric response, the predicted class for a categorical response or the median survival time for a censored response is returned. For "prob", the matrix of conditional class probabilities (simplify = TRUE) or a list with the conditional class probabilities for each observation (simplify = FALSE) is returned for a categorical response. For numeric and censored responses, type = "prob" returns a list with the empirical cumulative distribution functions or the empirical survivor functions (Kaplan-Meier estimates), respectively. "node" returns an integer vector of terminal node identifiers.
FUN: a function to extract (default method) or compute (constparty method) summary statistics. For the default method, this is a function of a terminal node only; for the constparty method, the prediction for each node has to be computed from the arguments (y, w), where y is the response and w are the case weights.
at: if the return value is a function (such as the empirical cumulative distribution function or the empirical quantile function), this function is evaluated at the values in at and these numeric values are returned. If at is NULL, the functions themselves are returned in a list.
simplify: a logical indicating whether the resulting list of predictions should be converted to a suitable vector or matrix (if possible).
...: additional arguments.
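The two FUN interfaces described above can be sketched as follows. This is a minimal sketch, assuming the partykit package is installed and using the built-in cars data with a ctree() fit (neither is part of the original page); the names ct, nd, wmean, pn, py and the node labels "slow" and "fast" are hypothetical. For the constparty method, a weighted mean passed as FUN(y, w) should reproduce the default "response" predictions; for the default method, FUN receives each terminal node and here simply returns its info slot.

```r
library("partykit")

## assumption: a conditional inference tree on the built-in 'cars' data;
## ctree() returns an object of class 'constparty'
ct <- ctree(dist ~ speed, data = cars)
nd <- cars[c(1, 25, 50), ]

## constparty method: FUN(y, w) receives the learning-sample responses y
## and case weights w falling into one terminal node
wmean <- function(y, w) weighted.mean(y, w)
p_fun <- predict(ct, newdata = nd, FUN = wmean)
p_mean <- predict(ct, newdata = nd)  # type = "response": node means
all.equal(p_fun, p_mean)

## default method: FUN is a function of a terminal node only.
## Hand-built tree splitting at speed <= 15; the terminal nodes carry
## hypothetical labels in their 'info' slot
pn <- partynode(1L, split = partysplit(1L, breaks = 15),
  kids = list(partynode(2L, info = "slow"), partynode(3L, info = "fast")))
py <- party(pn, data = cars)
p_info <- predict(py, newdata = nd, FUN = function(node) info_node(node))
p_info
```

Because the hand-built tree has no fitted values stored, newdata has to be supplied for it.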
A list of predictions, possibly simplified to a numeric vector, numeric matrix or factor.
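The possible shapes of the returned value can be checked directly. A minimal sketch, again assuming partykit, the built-in cars data and a ctree() fit; all object names are illustrative. It produces a numeric vector, an integer vector of node identifiers, a numeric matrix of quantiles evaluated at at, lists of functions (at = NULL, or type = "prob" for a numeric response) and an unsimplified list (simplify = FALSE).

```r
library("partykit")

## assumption: a ctree() fit on the built-in 'cars' data
ct <- ctree(dist ~ speed, data = cars)
nd <- cars[c(1, 50), ]

## numeric vector: mean response of the terminal node of each observation
p_resp <- predict(ct, newdata = nd, type = "response")

## integer vector of terminal node identifiers
p_node <- predict(ct, newdata = nd, type = "node")

## numeric matrix: empirical quantile functions evaluated at 'at'
p_q <- predict(ct, newdata = nd, type = "quantile", at = c(0.25, 0.5, 0.75))

## list of functions: with at = NULL nothing is evaluated
p_qfun <- predict(ct, newdata = nd, type = "quantile", at = NULL)
p_qfun[[1]](0.5)  # median in the first observation's node

## list of empirical CDFs (numeric response, type = "prob")
p_prob <- predict(ct, newdata = nd, type = "prob")

## list instead of a vector when simplification is switched off
p_list <- predict(ct, newdata = nd, simplify = FALSE)
```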
The predict method for party objects computes the identifiers of the predicted terminal nodes, either for new data in newdata or for the learning samples (only possible for objects of class constparty). These identifiers are delegated to the corresponding predict_party method, which computes (via FUN for class constparty) or extracts (class simpleparty) the actual predictions.
## fit tree using rpart
library("rpart")
library("partykit")
rp <- rpart(skips ~ Opening + Solder + Mask + PadType + Panel,
  data = solder, method = "anova")
## coerce to 'constparty'
pr <- as.party(rp)
## mean predictions
predict(pr, newdata = solder[c(3, 541, 640), ])
## empirical cumulative distribution functions
predict(pr, newdata = solder[c(3, 541, 640), ], type = "prob")
## terminal node identifiers
predict(pr, newdata = solder[c(3, 541, 640), ], type = "node")
## median predictions
predict(pr, newdata = solder[c(3, 541, 640), ],
  FUN = function(y, w = 1) median(y))
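The perm argument can be tried in the same style. A minimal sketch, assuming partykit, the built-in cars data and a ctree() fit (none of which appear in the rpart example): predictions are computed once as usual and once with the splits in speed evaluated on permuted values, and the two mean squared errors can then be compared. This mirrors the idea behind permutation variable importance; the helper mse is hypothetical, and no particular outcome is asserted because it depends on the random permutation.

```r
library("partykit")

## assumption: a ctree() fit on the built-in 'cars' data
ct <- ctree(dist ~ speed, data = cars)

## ordinary predictions for the learning data
p_orig <- predict(ct, newdata = cars)

## predictions with the primary splits in 'speed' permuted
set.seed(1)
p_perm <- predict(ct, newdata = cars, perm = "speed")

## mean squared error without and with the permutation
mse <- function(obs, pred) mean((obs - pred)^2)
c(original = mse(cars$dist, p_orig), permuted = mse(cars$dist, p_perm))
```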