Evaluates a fitted model on a dataset
evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
object: A fitted model to evaluate.
data: (dataloader, dataset or list) A dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list, used for evaluating the model. Dataloaders and datasets must return a list with at most 2 items. The first item will be used as input for the module and the second will be used as a target for the loss function.
...: Currently unused.
metrics: A list of luz metrics to be tracked during evaluation. If NULL (default), the same metrics that were used during training are tracked.
callbacks: (list, optional) A list of callbacks defined with luz_callback() that will be called during the evaluation procedure. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.
accelerator: (accelerator, optional) An optional accelerator() object used to configure device placement of components such as nn_modules, optimizers and batches of data.
verbose: (logical, optional) An optional boolean value indicating whether the procedure should emit output to the console while it runs. By default, it produces output if interactive() is TRUE; otherwise it won't print to the console.
dataloader_options: Options used when creating a dataloader. See torch::dataloader(). shuffle=TRUE by default for the training data and batch_size=32 by default. It will error if not NULL and data is already a dataloader.
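The two ways of supplying data described above can be sketched as follows. This is a minimal illustration, assuming the torch and luz packages are installed; the toy tensors, the `net` module and all variable names are hypothetical and exist only so that there is a fitted model to evaluate.

```r
library(torch)
library(luz)

# Hypothetical toy regression data: 100 rows, 3 features, 1 target.
x <- torch_randn(100, 3)
y <- torch_randn(100, 1)

# tensor_dataset() returns a list of (input, target) per item,
# which satisfies the "at most 2 items" requirement.
ds <- tensor_dataset(x, y)

# A tiny module, set up and fitted only to have something to evaluate.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(3, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)
model <- setup(net, loss = nn_mse_loss(), optimizer = optim_sgd)
model <- set_opt_hparams(model, lr = 0.01)
fitted <- fit(model, ds, epochs = 1, verbose = FALSE)

# Case 1: pass a dataset and let evaluate() build the dataloader,
# customising it through dataloader_options.
eval_ds <- evaluate(
  fitted, ds,
  dataloader_options = list(batch_size = 16),
  verbose = FALSE
)

# Case 2: pass a ready-made dataloader and leave dataloader_options
# as NULL, since combining the two is documented to raise an error.
dl <- dataloader(ds, batch_size = 16)
eval_dl <- evaluate(fitted, dl, verbose = FALSE)
```

Building the dataloader yourself is mainly useful when you need options beyond a simple list, while passing a dataset keeps the call compact.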
Once a model has been trained, you might want to evaluate its performance
on a different dataset. For that, luz provides the evaluate() function,
which takes a fitted model and a dataset and computes the metrics
attached to the model.
evaluate() returns a luz_module_evaluation object that you can query for
metrics using the get_metrics() function, or simply print to see the
results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
## A `luz_module_evaluation`
## -- Results ---------------------------------------------------------------------
## loss: 1.5146
## mae: 1.0251
## mse: 1.5159
## rmse: 1.2312
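The metrics argument can also be set explicitly at evaluation time instead of reusing the training metrics. The following is a minimal sketch, assuming the torch and luz packages are installed; the toy data, the `net` module and the variable names are hypothetical, and no output is shown because the values depend on random initialisation.

```r
library(torch)
library(luz)

# Hypothetical toy data and model, used only for illustration.
x <- torch_randn(64, 3)
y <- torch_randn(64, 1)
ds <- tensor_dataset(x, y)

net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(3, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Train while tracking mean absolute error.
model <- setup(
  net,
  loss = nn_mse_loss(),
  optimizer = optim_sgd,
  metrics = list(luz_metric_mae())
)
model <- set_opt_hparams(model, lr = 0.01)
fitted <- fit(model, ds, epochs = 1, verbose = FALSE)

# metrics = NULL (the default): the training metrics are reused.
default_eval <- evaluate(fitted, ds, verbose = FALSE)

# Explicit metrics: track RMSE for this evaluation instead.
custom_eval <- evaluate(
  fitted, ds,
  metrics = list(luz_metric_rmse()),
  verbose = FALSE
)

print(get_metrics(default_eval))
print(get_metrics(custom_eval))
```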
Other training: fit.luz_module_generator(), predict.luz_module_fitted(), setup()