Prediction of artificial neural network of class nn, produced by neuralnet().
# S3 method for nn
predict(object, newdata, rep = 1, all.units = FALSE, ...)
object: Neural network of class nn.
newdata: New data of class data.frame or matrix.
rep: Integer indicating which repetition of the neural network should be used.
all.units: Return output for all units instead of final output only.
...: Further arguments passed to or from other methods.
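The rep argument only matters when the network was trained with more than one repetition. The following is a minimal sketch, assuming the model was fit with neuralnet(..., rep = 3) and that the fitted object's result.matrix has an "error" row with one column per repetition that converged. The seed, the split, and the names errors, best_rep and pred_best are illustrative and not part of the package.

```r
library(neuralnet)

set.seed(42)  # illustrative seed, only to make the split reproducible
train_idx <- sample(nrow(iris), 100)
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

# Train three repetitions of the same network; each starts from
# different random weights
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width,
                iris_train, linear.output = FALSE, rep = 3)

# Assumption: one column of result.matrix per converged repetition
errors <- nn$result.matrix["error", ]
best_rep <- unname(which.min(errors))

# Predict with the selected repetition instead of the default rep = 1
pred_best <- predict(nn, iris_test, rep = best_rep)
```

Repetitions that fail to converge are typically not kept in the fitted object, so it is safer to derive the index from the object itself than to hard-code it.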
Matrix of predictions. Each column represents one output unit. If all.units = TRUE, a list of matrices with output for each unit.
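To make the two return shapes concrete, here is a small sketch that calls predict() both ways on the same rows. It assumes the default network architecture; the seed and the names newdata, pred and units are illustrative. The exact number of list elements depends on the network, so the code inspects it rather than assuming a value.

```r
library(neuralnet)

set.seed(1)  # illustrative seed
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width,
                iris, linear.output = FALSE)

newdata <- iris[1:5, ]

# Default: a single matrix, one row per observation,
# one column per output unit
pred <- predict(nn, newdata)
dim(pred)

# all.units = TRUE: a list of matrices instead of a single matrix
units <- predict(nn, newdata, all.units = TRUE)
length(units)
lapply(units, dim)
```

Each matrix in the list has one row per observation in newdata, which makes it straightforward to inspect intermediate activations for individual rows.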
library(neuralnet)

# Split data
train_idx <- sample(nrow(iris), 2/3 * nrow(iris))
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

# Binary classification
nn <- neuralnet(Species == "setosa" ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species == "setosa", pred[, 1] > 0.5)

# Multiclass classification
nn <- neuralnet((Species == "setosa") + (Species == "versicolor") + (Species == "virginica") ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)
table(iris_test$Species, apply(pred, 1, which.max))
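In the multiclass example, which.max returns column indices rather than class names. The sketch below maps those indices back to labels and computes a simple accuracy. It assumes the prediction columns follow the order of the terms on the left-hand side of the formula; the seed and the names class_names, pred_class and accuracy are illustrative.

```r
library(neuralnet)

set.seed(123)  # illustrative seed
train_idx <- sample(nrow(iris), 100)
iris_train <- iris[train_idx, ]
iris_test <- iris[-train_idx, ]

nn <- neuralnet((Species == "setosa") + (Species == "versicolor") + (Species == "virginica") ~ Petal.Length + Petal.Width, iris_train, linear.output = FALSE)
pred <- predict(nn, iris_test)

# Assumption: columns of pred are in the same order as the
# left-hand-side terms of the formula
class_names <- c("setosa", "versicolor", "virginica")
pred_class <- class_names[apply(pred, 1, which.max)]

# Confusion table with readable labels, and overall accuracy
table(actual = iris_test$Species, predicted = pred_class)
accuracy <- mean(pred_class == as.character(iris_test$Species))
```

Training can occasionally fail to converge for a given seed, in which case neuralnet() issues a warning and predict() has no weights to use; rerunning with a different seed or a larger stepmax is the usual remedy.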