These functions can be used for a single train object or to loop through a number of train objects to calculate the training and test data predictions and class probabilities.

## S3 method for class 'list':
predict(object, ...)

## S3 method for class 'train':
predict(object, newdata = NULL, type = "raw", ...)
extractPrediction(models,
testX = NULL, testY = NULL,
unkX = NULL,
unkOnly = !is.null(unkX) & is.null(testX),
verbose = FALSE)
extractProb(models,
testX = NULL, testY = NULL,
unkX = NULL,
unkOnly = !is.null(unkX) & is.null(testX),
verbose = FALSE)
If newdata is NULL, then the original training data are used.
The train objects given in models must have been generated with fitBest = FALSE and returnData = TRUE.
testY holds the observed outcomes for the samples in testX.
For predict.train, the result is a vector of predictions if type = "raw" or a data frame of class probabilities if type = "prob"; in the latter case there is one column for each class. For predict.list, the result is a list in which each element is produced by predict.train.
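A minimal sketch of the two prediction types, assuming a current caret release. It uses the built-in iris data instead of Satellite so that it needs no extra package; the 5-fold cross-validation and classProbs = TRUE settings are illustrative choices, not requirements stated on this page.

```r
library(caret)

# Illustrative split of the built-in iris data: 100 training rows, 50 test rows
set.seed(716)
inTrain <- sample(nrow(iris), 100)
trainX <- iris[inTrain, 1:4]
trainY <- iris$Species[inTrain]
testX  <- iris[-inTrain, 1:4]

# 5-fold CV keeps the example fast; classProbs = TRUE is set as a precaution for type = "prob"
ctrl <- trainControl(method = "cv", number = 5, classProbs = TRUE)
knnFit <- train(trainX, trainY, method = "knn", trControl = ctrl)

# type = "raw": a factor with one predicted class per row of testX
rawPred <- predict(knnFit, newdata = testX, type = "raw")

# type = "prob": a data frame with one probability column per class
probPred <- predict(knnFit, newdata = testX, type = "prob")
head(probPred)
```

Each row of probPred should sum to one, and its columns correspond to the levels of trainY.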
For extractPrediction, a data frame that includes a pred column (the predictions) and an object column giving the names of the elements of models. If models is an unnamed list, the values of object will be "Object1", "Object2" and so on.
For extractProb, a data frame with one column of probabilities for each class. The remaining columns are the same as above, except that the pred column holds the predicted class.
The tuning parameter values stored in the tuneValue slot of the finalModel object are used to predict.
To get simple predictions for a new data set, use the predict function. Limits can be imposed on the range of predictions; see trainControl for more information.
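As a hedged illustration of such limits, the sketch below uses the predictionBounds argument of trainControl as it exists in current caret releases (regression only). The data frame df and its linear relationship are hypothetical; a numeric bound of c(0, NA) is meant to raise any prediction below 0 to 0 and leave the upper end unconstrained.

```r
library(caret)

# Hypothetical data with a non-negative outcome: y is roughly 2 * x
set.seed(1)
df <- data.frame(x = 1:50)
df$y <- 2 * df$x + rnorm(50)

# Numeric predictionBounds: lower limit 0, NA means no upper limit
ctrl <- trainControl(method = "cv", number = 5, predictionBounds = c(0, NA))
lmFit <- train(y ~ x, data = df, method = "lm", trControl = ctrl)

# x = -10 lies outside the training range; the unbounded linear fit gives about -20
bounded <- predict(lmFit, newdata = data.frame(x = c(-10, 25)))
bounded
```

Passing a logical vector such as c(TRUE, FALSE) instead would bound predictions by the range of the training outcomes rather than by fixed values.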
To get predictions for a series of models at once, pass a list of train objects to the predict function; a list of model predictions will be returned.
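A short sketch of predicting with a list of models and scoring each one on a test set, assuming a current caret release (with the rpart package available) and using iris in place of Satellite. Extra arguments such as newdata are forwarded through predict.list to predict.train; the accuracy summary is an illustrative addition, not part of the functions documented here.

```r
library(caret)

# Illustrative iris split: 100 training rows, 50 test rows
set.seed(716)
inTrain <- sample(nrow(iris), 100)
trainX <- iris[inTrain, 1:4]
trainY <- iris$Species[inTrain]
testX  <- iris[-inTrain, 1:4]
testY  <- iris$Species[-inTrain]

ctrl <- trainControl(method = "cv", number = 5)
bothModels <- list(
  knn  = train(trainX, trainY, method = "knn",   trControl = ctrl),
  tree = train(trainX, trainY, method = "rpart", trControl = ctrl))

# predict() on a list returns a list of prediction vectors, one per model
testPreds <- predict(bothModels, newdata = testX)

# Fraction of correctly classified test samples for each model
accuracy <- sapply(testPreds, function(p) mean(as.character(p) == as.character(testY)))
accuracy
```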
The two extraction functions return the predictions and the observed outcomes for the training, test and/or unknown samples together in a single data frame (instead of a list of just the predictions). These data frames can then be passed to plotObsVsPred or plotClassProbs.
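One way to use that single data frame is to summarise accuracy per model and per data set. This is a sketch assuming a current caret release, in which the extractPrediction result has columns named obs, pred, dataType and object (obs and dataType are assumptions here, since their names are not given above); iris again stands in for Satellite.

```r
library(caret)

# Illustrative iris split, as a stand-in for the Satellite example
set.seed(716)
inTrain <- sample(nrow(iris), 100)
trainX <- iris[inTrain, 1:4]
trainY <- iris$Species[inTrain]
testX  <- iris[-inTrain, 1:4]
testY  <- iris$Species[-inTrain]

ctrl <- trainControl(method = "cv", number = 5)
bothModels <- list(
  knn  = train(trainX, trainY, method = "knn",   trControl = ctrl),
  tree = train(trainX, trainY, method = "rpart", trControl = ctrl))

# One row per (model, sample) for the training and test sets
predTargets <- extractPrediction(bothModels, testX = testX, testY = testY)

# Assumed column names (current caret): obs, pred, dataType, object
predTargets$correct <- as.character(predTargets$obs) == as.character(predTargets$pred)
accBySet <- aggregate(correct ~ object + dataType, data = predTargets, FUN = mean)
accBySet
```

Comparing the "Training" and "Test" rows for each model gives a quick check for overfitting.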
See also: plotObsVsPred, plotClassProbs, trainControl
library(caret)
library(mlbench)
data(Satellite)

numSamples <- nrow(Satellite)
set.seed(716)

# Draw disjoint training (150), test (100) and unknown (50) samples
varIndex <- 1:numSamples
trainSamples <- sample(varIndex, 150)

varIndex <- (1:numSamples)[-trainSamples]
testSamples <- sample(varIndex, 100)

varIndex <- (1:numSamples)[-c(testSamples, trainSamples)]
unkSamples <- sample(varIndex, 50)

# Column 37 is the class label; columns 1-36 are the predictors
trainX <- Satellite[trainSamples, -37]
trainY <- Satellite[trainSamples, 37]

testX <- Satellite[testSamples, -37]
testY <- Satellite[testSamples, 37]

unkX <- Satellite[unkSamples, -37]

knnFit <- train(trainX, trainY, method = "knn")
rpartFit <- train(trainX, trainY, method = "rpart")

# Class predictions for the training data, then for the test set
predict(knnFit)
predict(knnFit, newdata = testX)

# Class probabilities for the training data
predict(knnFit, type = "prob")

# Predictions from several models at once
bothModels <- list(
  knn = knnFit,
  tree = rpartFit)

predict(bothModels)

# Training, test and unknown predictions in one data frame
predTargets <- extractPrediction(
  bothModels,
  testX = testX,
  testY = testY,
  unkX = unkX)