Generate predictions from an "nn_fit" object produced by train_nn().
Three S3 methods are registered:
  * predict.nn_fit(): base method for models trained on a matrix.
  * predict.nn_fit_tab(): extends the base method for tabular fits; it
    preprocesses new data with hardhat::forge() before predicting.
  * predict.nn_fit_ds(): extends the base method for fits trained on a
    torch dataset.
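The subclass-extends-base relationship above follows ordinary S3 dispatch. The toy below is a minimal sketch of that pattern, not the package's implementation: the classes toy_fit and toy_fit_tab, the list fields coef and predictors, and the column-selection step standing in for hardhat::forge() are all hypothetical. It only shows how a subclass method can preprocess its input and then hand off to the base method with NextMethod().

```r
# Hypothetical base method for the stats::predict() generic:
# the "model" is just a coefficient vector applied to a matrix.
predict.toy_fit <- function(object, newdata = NULL, ...) {
  x <- as.matrix(newdata)
  drop(x %*% object$coef)
}

# Hypothetical subclass method. The column selection is a stand-in for
# hardhat::forge(): keep only the training predictors, in training order.
# NextMethod() then calls predict.toy_fit() with the modified newdata.
predict.toy_fit_tab <- function(object, newdata = NULL, ...) {
  newdata <- newdata[, object$predictors, drop = FALSE]
  NextMethod()
}

fit <- structure(
  list(coef = c(a = 2, b = -1), predictors = c("a", "b")),
  class = c("toy_fit_tab", "toy_fit")
)
# Columns arrive out of order and with an extra column; the subclass fixes that.
df <- data.frame(b = c(1, 0), a = c(3, 5), extra = c(9, 9))
pred <- predict(fit, newdata = df)  # 5 and 10
```

Because the class vector lists the subclass first, dispatch reaches predict.toy_fit_tab() first, and the base method never has to know about data.frame inputs.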
Usage:

# S3 method for nn_fit
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)

# S3 method for nn_fit_tab
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)

# S3 method for nn_fit_ds
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
Value:

  Regression: a numeric vector (single output) or a numeric matrix
  (multiple outputs).

  Classification, type = "response": a factor whose levels match those
  seen during training.

  Classification, type = "prob": a numeric matrix with one column per
  class; columns are named by class label.
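For classification, the two output types are related in the usual way: a predicted label is the class with the highest probability. This page does not say how the package breaks ties or builds the factor, so probs_to_labels() below is a hypothetical helper, assuming the first maximum wins and that the column names of the probability matrix are the training levels.

```r
# Hypothetical helper: convert a class-probability matrix (one column per
# class, columns named by class label) into a factor of predicted labels.
probs_to_labels <- function(probs) {
  stopifnot(is.matrix(probs), !is.null(colnames(probs)))
  # Column index of the most probable class in each row;
  # assumption: ties go to the first such column.
  best <- max.col(probs, ties.method = "first")
  # Keep every class as a level, even ones that are never predicted.
  factor(colnames(probs)[best], levels = colnames(probs))
}

probs <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.3, 0.6),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("setosa", "versicolor", "virginica"))
)
labels <- probs_to_labels(probs)
# labels: setosa, virginica; levels keep all three classes
```

Fixing the levels to the full column set mirrors the documented guarantee that response factors carry the levels seen during training, not just the ones that happen to be predicted.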
Arguments:

  object: A fitted model object returned by train_nn().
  newdata: New predictor data. Accepted forms depend on the method:
    * predict.nn_fit(): a numeric matrix or an object coercible to one.
    * predict.nn_fit_tab(): a data.frame with the same columns used during
      training; preprocessing is applied automatically via hardhat::forge().
    * predict.nn_fit_ds(): a torch dataset, numeric array, matrix, or
      data.frame.
    If NULL, the cached fitted values from training are returned (not
    available for type = "prob").
  new_data: Legacy alias for newdata, retained for backward compatibility.
  type: Character. Output type:
    * "response" (default): predicted class labels (a factor) for
      classification, or a numeric vector or matrix for regression.
    * "prob": a numeric matrix of class probabilities (classification only).
  ...: Currently unused; reserved for future extensions.
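The interplay of newdata, new_data, and type can be summarised in a small helper. resolve_predict_args() is hypothetical and not part of the package; in particular, giving newdata precedence when both arguments are supplied, and the wording of the error, are assumptions made for illustration. The NULL-means-cached rule and the restriction on type = "prob" come from the argument descriptions above.

```r
# Hypothetical helper mirroring the documented argument rules.
resolve_predict_args <- function(newdata = NULL, new_data = NULL,
                                 type = c("response", "prob")) {
  # Assumption: the legacy alias is consulted only when newdata is NULL.
  if (is.null(newdata)) newdata <- new_data
  # Validate type against the two documented choices; default "response".
  type <- match.arg(type)
  # Documented: with no new data, cached fitted values are returned,
  # and those are not available for type = "prob".
  if (is.null(newdata) && identical(type, "prob")) {
    stop("type = \"prob\" is not available for cached fitted values; supply newdata.")
  }
  list(use_cached = is.null(newdata), newdata = newdata, type = type)
}

a <- resolve_predict_args(new_data = matrix(1:4, 2))  # legacy alias used
b <- resolve_predict_args()                           # cached fitted values
err <- tryCatch(resolve_predict_args(type = "prob"),
                error = function(e) "error")          # rejected
```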
See Also:

  train_nn()