Fit a nn_module
# S3 method for luz_module_generator
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)
Returns a fitted object that can be saved with luz_save(), printed with print() and plotted with plot().
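The following is a minimal end-to-end sketch, assuming the torch and luz packages are installed and R >= 4.1 (for the native pipe). The toy model `net` and the random tensors `x` and `y` are hypothetical and only illustrate the call sequence: configure the module with setup(), train it with fit(), then print and save the result.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Random regression data, for illustration only.
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# setup() attaches the loss and optimizer; fit() then trains for 2 epochs.
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 2, verbose = FALSE)

print(fitted)

# Save the fitted object to a temporary file and load it back.
path <- tempfile(fileext = ".luz")
luz_save(fitted, path)
reloaded <- luz_load(path)
```

plot(fitted) can be called on the same object to visualize the metrics recorded during training.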
object: An nn_module that has been configured with setup().
data: (dataloader, dataset or list) The data used for training the model: a dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list. Dataloaders and datasets must return a list with at most 2 items. The first item is used as the input to the module and the second as the target for the loss function.
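As a sketch of the accepted input types, assuming torch and luz are installed, the snippet below builds a dataset with torch::tensor_dataset(), whose items are list(input, target), and fits the same configured model once from a dataloader and once from the dataset directly. The model `net` is a hypothetical toy example.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Each item of a tensor_dataset is list(input, target): exactly 2 items.
ds <- tensor_dataset(torch_randn(100, 10), torch_randn(100, 1))
dl <- dataloader(ds, batch_size = 25, shuffle = TRUE)

model <- setup(net, loss = nn_mse_loss(), optimizer = optim_adam)

# The same configured model can be fitted from a dataloader or a dataset.
fitted_dl <- fit(model, dl, epochs = 1, verbose = FALSE)
fitted_ds <- fit(model, ds, epochs = 1, verbose = FALSE)

# Inspect one item to confirm the (input, target) structure.
item <- ds$.getitem(1)
```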
epochs: (int) The maximum number of epochs for training the model, or a pair of minimum and maximum values. If a single value is provided, it is taken as max_epochs and min_epochs is set to 0. If a vector of two numbers is provided, the first value is min_epochs and the second is max_epochs. The minimum and maximum numbers of epochs are included in the context object as ctx$min_epochs and ctx$max_epochs, respectively.
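To see the two-value form in action, here is a sketch, assuming torch and luz are installed, that passes epochs = c(1, 3) and uses a hypothetical callback, `record_limits`, to read ctx$min_epochs and ctx$max_epochs at the start of fitting. The values are copied into an environment, `seen`, so they can be inspected after training.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Hypothetical callback: stores the epoch limits found in the context.
seen <- new.env()
record_limits <- luz_callback(
  name = "record_limits",
  on_fit_begin = function() {
    seen$limits <- c(ctx$min_epochs, ctx$max_epochs)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# epochs = c(1, 3): train for at least 1 and at most 3 epochs.
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = c(1, 3), callbacks = list(record_limits()),
      verbose = FALSE)

print(seen$limits)
```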
callbacks: (list, optional) A list of callbacks defined with luz_callback() that are called during training. The callbacks luz_callback_metrics(), luz_callback_progress() and luz_callback_train_valid() are always added by default.
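A custom callback is passed as an instance inside the list. The sketch below assumes torch and luz are installed. The hypothetical `count_batches` callback increments a counter stored in an environment after every training batch. With 100 rows, a batch size of 25 and 2 epochs, it should fire 8 times.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

# Counter kept in an environment so the callback can update it by reference.
counter <- new.env()
counter$batches <- 0

# Hypothetical callback that counts training batches.
count_batches <- luz_callback(
  name = "count_batches",
  on_train_batch_end = function() {
    counter$batches <- counter$batches + 1
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 2, callbacks = list(count_batches()),
      verbose = FALSE, dataloader_options = list(batch_size = 25))

print(counter$batches)
```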
valid_data: (dataloader, dataset, list or scalar value; optional) A dataloader created with torch::dataloader(), a dataset created with torch::dataset(), or a list, used during validation. It must return a list of (input, target). If data is a torch dataset or a list, you can instead supply a numeric value between 0 and 1. In that case, a random sample of data with that proportion is held out and used for validation.
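The proportion form can be sketched as follows, assuming torch and luz are installed. Here valid_data = 0.2 holds out a random 20% of the rows in `list(x, y)`. The recorded metrics, retrieved with luz's get_metrics(), should then include entries for the validation set alongside the training set.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Hold out a random 20% of the rows for validation.
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 2, valid_data = 0.2, verbose = FALSE)

# One row per set, metric and epoch.
metrics <- get_metrics(fitted)
print(metrics)
```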
accelerator: (accelerator, optional) An optional accelerator() object used to configure device placement of components such as torch nn_modules, optimizers and batches of data.
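As a sketch, assuming torch and luz are installed, the snippet below forces training on the CPU by passing accelerator(cpu = TRUE). This can be useful for reproducible runs on a machine that also has a GPU. The model is a hypothetical toy example.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# cpu = TRUE keeps the module, optimizer and batches on the CPU.
fitted <- net |>
  setup(loss = nn_mse_loss(), optimizer = optim_adam) |>
  fit(list(x, y), epochs = 1, accelerator = accelerator(cpu = TRUE),
      verbose = FALSE)
```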
verbose: (logical, optional) Whether the fitting procedure should print output to the console during training. By default, output is printed if interactive() is TRUE; otherwise nothing is printed.
...: Currently unused.
dataloader_options: Options used when creating a dataloader; see torch::dataloader(). By default, shuffle = TRUE for the training data and batch_size = 32. Supplying a non-NULL value when data is already a dataloader raises an error.
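The sketch below, assuming torch and luz are installed, passes dataloader_options for a dataset input. It then checks the documented error when the same option is combined with data that is already a dataloader. The model and data are hypothetical.

```r
library(torch)
library(luz)

# Hypothetical toy model: a single linear layer from 10 features to 1 output.
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$fc(x)
  }
)

ds <- tensor_dataset(torch_randn(100, 10), torch_randn(100, 1))
model <- setup(net, loss = nn_mse_loss(), optimizer = optim_adam)

# For a dataset, the options are used to build the training dataloader.
fitted <- fit(model, ds, epochs = 1, verbose = FALSE,
              dataloader_options = list(batch_size = 10, shuffle = FALSE))

# For data that is already a dataloader, supplying options is an error.
dl <- dataloader(ds, batch_size = 10)
res <- tryCatch(
  fit(model, dl, epochs = 1, verbose = FALSE,
      dataloader_options = list(batch_size = 10)),
  error = function(e) e
)
```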
See also: predict.luz_module_fitted() for how to create predictions, and setup() for how to create modules that can be trained with fit().
Other training: evaluate(), predict.luz_module_fitted(), setup()