
cito (version 1.1)

predict.citodnn: Predict from a fitted dnn model

Description

Predict from a fitted dnn model

Usage

# S3 method for citodnn
predict(
  object,
  newdata = NULL,
  type = c("link", "response", "class"),
  device = c("cpu", "cuda", "mps"),
  reduce = c("mean", "median", "none"),
  ...
)

# S3 method for citodnnBootstrap
predict(
  object,
  newdata = NULL,
  type = c("link", "response", "class"),
  device = c("cpu", "cuda", "mps"),
  reduce = c("mean", "median", "none"),
  ...
)

Value

prediction matrix

Arguments

object

a model created by dnn

newdata

new data for predictions. If NULL (the default), predictions are made for the data the model was fitted on.

type

type of predictions. The default, "link", returns predictions on the scale of the linear predictor; "response" returns them on the scale of the response; and "class" returns predicted classes (classification tasks only).
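The relation between the three type options for a classification network can be sketched in base R. The sketch assumes a softmax output layer, where "response" turns link-scale scores (logits) into class probabilities and "class" returns the most probable class. The matrix link_scores is made-up data, the helper softmax_rows() is hypothetical, and the code does not call cito.

```r
# Made-up link-scale scores (logits) for 3 observations and 3 classes
link_scores <- matrix(c( 2.0, 0.5, -1.0,
                         0.1, 1.5,  0.3,
                        -0.5, 0.0,  2.5),
                      nrow = 3, byrow = TRUE,
                      dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# Hypothetical helper: row-wise softmax, assumed to mirror type = "response"
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))  # subtract each row's max for numerical stability
  e / rowSums(e)                  # each row now sums to 1
}
response <- softmax_rows(link_scores)

# Assumed equivalent of type = "class": label of the most probable class per row
class_pred <- factor(colnames(response)[apply(response, 1, which.max)],
                     levels = colnames(response))
print(class_pred)
```

Because the softmax is monotone within each row, taking the row-wise maximum of the link scores or of the probabilities selects the same class.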

device

device on which the predictions are computed ("cpu", "cuda" or "mps"; default "cpu").
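As a sketch of how the device argument might be chosen at run time, the hypothetical helper pick_device() below returns "cuda" or "mps" when the R torch package reports that backend as available, and otherwise falls back to "cpu". It is not part of cito and assumes torch's cuda_is_available() and backends_mps_is_available() functions; the latter is only present in newer torch releases, so its existence is checked first.

```r
# Hypothetical helper (not part of cito): pick an available device, else "cpu"
pick_device <- function() {
  if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
    if (torch::cuda_is_available()) return("cuda")
    # backends_mps_is_available() only exists in newer torch releases
    has_mps_check <- exists("backends_mps_is_available",
                            envir = asNamespace("torch"), inherits = FALSE)
    if (has_mps_check && torch::backends_mps_is_available()) return("mps")
  }
  "cpu"
}

print(pick_device())
```

With a fitted model, the result could then be passed on as predict(nn.fit, newdata, device = pick_device()).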

reduce

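for bootstrapped models (citodnnBootstrap), how the predictions of the individual bootstrap models are combined: "mean" (default) averages them, "median" takes their median, and "none" returns them without reduction.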
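The three reduce options can be illustrated without cito. The sketch below assumes the bootstrap predictions are stacked in a hypothetical array boot_preds with dimensions observations by outputs by bootstrap models, and combines them along the last dimension. The numbers are invented, and the exact shape cito returns for reduce = "none" may differ.

```r
# Invented predictions from 4 bootstrap models for 3 observations, 1 output
# dim = c(observations, outputs, bootstrap models)
boot_preds <- array(c(5.0, 6.1, 6.9,   # bootstrap model 1
                      5.2, 5.9, 7.1,   # bootstrap model 2
                      4.9, 6.0, 7.0,   # bootstrap model 3
                      5.5, 6.2, 6.8),  # bootstrap model 4
                    dim = c(3, 1, 4))

# reduce = "mean": average over the bootstrap dimension
pred_mean <- apply(boot_preds, c(1, 2), mean)

# reduce = "median": median over the bootstrap dimension
pred_median <- apply(boot_preds, c(1, 2), median)

# reduce = "none": keep every bootstrap model's prediction
pred_none <- boot_preds

print(pred_mean)    # 5.15, 6.05, 6.95
print(pred_median)  # 5.10, 6.05, 6.95
```

Keeping the unreduced predictions (reduce = "none") is what allows, for example, bootstrap intervals to be computed from the spread across models.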

...

additional arguments

Examples

# \donttest{
if(torch::torch_is_installed()){
library(cito)

set.seed(222)
validation_set <- sample(seq_len(nrow(datasets::iris)), 25)

# Build and train the network on the remaining observations
nn.fit <- dnn(Sepal.Length ~ ., data = datasets::iris[-validation_set, ])

# Use model on validation set
predictions <- predict(nn.fit, datasets::iris[validation_set, ])
# Scatterplot of observed vs. predicted values
plot(datasets::iris[validation_set, ]$Sepal.Length, predictions,
     xlab = "Observed Sepal.Length", ylab = "Predicted Sepal.Length")
# MAE
mean(abs(predictions - datasets::iris[validation_set, ]$Sepal.Length))
}
# }
