
luz

Luz is a higher-level API for torch that provides abstractions for writing much less verbose training loops.

This package is still under development.

It is heavily inspired by other higher-level deep learning frameworks, to cite a few:

  • FastAI: We are heavily inspired by the FastAI library, especially its Learner object and callbacks API.

  • Keras: We are also heavily inspired by Keras, especially its callback names. The luz module interface (setup()) is similar to compile() in Keras, too.

  • PyTorch Lightning: The idea of luz_module being a subclass of nn_module is inspired by the LightningModule object in Lightning.

  • HuggingFace Accelerate: The internal device placement API is heavily inspired by Accelerate, but is much more modest in features. Currently, only CPU and single-GPU training are supported.

Installation

You can install the released version from CRAN with:

install.packages("luz")

or the development version with:

remotes::install_github("mlverse/luz")

Example

Luz lets you take your torch nn_module definition and fit it to a dataloader, while handling the boring parts like moving data between devices, updating the weights, showing progress bars and tracking metrics.
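For comparison, here is roughly what those boring parts look like when written by hand in plain torch. This is a minimal sketch, not luz code: the random toy data, the nn_linear model, the batch size and the learning rate are assumptions chosen for illustration, and it assumes the torch package (and coro, which torch installs) is available.

```r
library(torch)

torch_manual_seed(1)

# Hypothetical toy regression data: 64 samples with 10 features each
x <- torch_randn(64, 10)
y <- torch_randn(64, 1)
dl <- dataloader(tensor_dataset(x, y), batch_size = 16, shuffle = TRUE)

# Device placement by hand (luz's fit() handles this for you)
device <- if (cuda_is_available()) "cuda" else "cpu"
model <- nn_linear(10, 1)
model$to(device = device)
opt <- optim_adam(model$parameters, lr = 0.01)

for (epoch in 1:3) {
  losses <- c()
  coro::loop(for (batch in dl) {
    opt$zero_grad()
    # Move each batch to the same device as the model
    input <- batch[[1]]$to(device = device)
    target <- batch[[2]]$to(device = device)
    loss <- nnf_mse_loss(model(input), target)
    loss$backward()  # compute gradients
    opt$step()       # update the weights
    losses <- c(losses, loss$item())
  })
  # Hand-rolled progress reporting and metric tracking
  cat(sprintf("Epoch %d - mean loss: %.4f\n", epoch, mean(losses)))
}
```

Every line inside the loop is something luz's fit() takes care of, along with progress bars and validation.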

Here's an example that defines and trains an autoencoder on the MNIST dataset. We selected parts of the code to highlight luz functionality. You can find the full example code here.

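library(torch)
library(luz)

net <- nn_module(
  "Net",
  initialize = function() {
    # Encoder: two 5x5 convolutions, 1 x 28 x 28 -> 16 x 20 x 20
    self$encoder <- nn_sequential(
      nn_conv2d(1, 6, kernel_size = 5),
      nn_relu(),
      nn_conv2d(6, 16, kernel_size = 5),
      nn_relu()
    )
    # Decoder: two 5x5 transposed convolutions, 16 x 20 x 20 -> 1 x 28 x 28
    self$decoder <- nn_sequential(
      nn_conv_transpose2d(16, 6, kernel_size = 5),
      nn_relu(),
      nn_conv_transpose2d(6, 1, kernel_size = 5),
      nn_sigmoid()  # outputs in (0, 1), like pixel intensities scaled to [0, 1]
    )
  },
  forward = function(x) {
    x %>%
      self$encoder() %>%
      self$decoder()
  }
)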
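A quick way to check the architecture is to push a fake batch through the same layers and inspect the shapes. The sketch below rebuilds the encoder and decoder stacks outside the module and uses random 28x28 inputs in place of real MNIST images (an assumption for illustration). With 5x5 kernels, stride 1 and no padding, each convolution shrinks the spatial size by 4 and each transposed convolution grows it back by 4, so the decoder returns images of the input size.

```r
library(torch)

# Same layer stacks as the Net encoder and decoder
encoder <- nn_sequential(
  nn_conv2d(1, 6, kernel_size = 5), nn_relu(),
  nn_conv2d(6, 16, kernel_size = 5), nn_relu()
)
decoder <- nn_sequential(
  nn_conv_transpose2d(16, 6, kernel_size = 5), nn_relu(),
  nn_conv_transpose2d(6, 1, kernel_size = 5), nn_sigmoid()
)

# Fake batch of two single-channel 28x28 images
x <- torch_randn(2, 1, 28, 28)
code <- encoder(x)
recon <- decoder(code)

print(code$shape)   # 2 16 20 20
print(recon$shape)  # 2 1 28 28
```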

Now that we have defined the Autoencoder architecture using torch::nn_module(), we can fit it using luz:

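# train_dl and test_dl are the MNIST dataloaders from the full example; for an
# autoencoder, each batch's target must be the input image itself
fitted <- net %>%
  setup(
    loss = nn_mse_loss(),  # reconstruction error between output and target
    optimizer = optim_adam
  ) %>%
  fit(train_dl, epochs = 1, valid_data = test_dl)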
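The same pipeline also accepts metrics and hyper-parameters. The sketch below is self-contained: it swaps the autoencoder for torch's nn_linear and uses random toy data in place of train_dl and test_dl, both assumptions for illustration. set_hparams() passes its arguments to the module's initialize(), set_opt_hparams() passes them to the optimizer, and accelerator(cpu = TRUE) forces training on the CPU even when a GPU is present.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy regression data standing in for train_dl / test_dl
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)
train_dl <- dataloader(tensor_dataset(x, y), batch_size = 20, shuffle = TRUE)
valid_dl <- dataloader(tensor_dataset(x, y), batch_size = 20)

fitted <- nn_linear %>%
  setup(
    loss = nn_mse_loss(),
    optimizer = optim_adam,
    metrics = list(luz_metric_mae())  # tracked alongside the loss
  ) %>%
  set_hparams(in_features = 10, out_features = 1) %>%  # -> nn_linear initialize()
  set_opt_hparams(lr = 0.01) %>%                        # -> optim_adam
  fit(
    train_dl,
    epochs = 3,
    valid_data = valid_dl,
    accelerator = accelerator(cpu = TRUE),
    verbose = FALSE
  )

preds <- predict(fitted, valid_dl)  # one prediction per validation sample
metrics <- get_metrics(fitted)      # data.frame of per-epoch metric values
```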


Monthly Downloads: 1,429
Version: 0.4.0
License: MIT + file LICENSE


Maintainer: Daniel Falbel
Last Published: April 17th, 2023

Functions in luz (0.4.0)

  • ctx: Context object
  • context: Context object
  • evaluate: Evaluates a fitted model on a dataset
  • lr_finder: Learning Rate Finder
  • luz_callback_auto_resume: Resume training callback
  • luz_callback: Create a new callback
  • as_dataloader: Creates a dataloader from its input
  • fit.luz_module_generator: Fit an nn_module
  • accelerator: Create an accelerator
  • get_metrics: Get metrics from the object
  • luz_callback_keep_best_model: Keep the best model
  • luz_callback_metrics: Metrics callback
  • luz_callback_early_stopping: Early stopping callback
  • luz_callback_profile: Profile callback
  • luz_callback_gradient_clip: Gradient clipping callback
  • luz_callback_mixup: Mixup callback
  • luz_callback_model_checkpoint: Checkpoints model weights
  • luz_callback_interrupt: Interrupt callback
  • luz_callback_csv_logger: CSV logger callback
  • luz_load_model_weights: Loads model weights into a fitted object
  • luz_metric_accuracy: Accuracy
  • luz_callback_resume_from_checkpoint: Allows resuming model training from a specific checkpoint
  • luz_callback_lr_scheduler: Learning rate scheduler callback
  • luz_load: Load trained model
  • luz_metric_mse: Mean squared error
  • luz_load_checkpoint: Loads a checkpoint
  • luz_callback_progress: Progress callback
  • nn_mixup_loss: Loss to be used with luz_callback_mixup()
  • nnf_mixup: Mixup logic
  • luz_metric_mae: Mean absolute error
  • luz_callback_tfevents: tfevents callback
  • luz_callback_train_valid: Train-eval callback
  • luz_save: Saves luz objects to disk
  • set_hparams: Set hyper-parameters of a module
  • luz_metric_set: Creates a metric set
  • reexports: Objects exported from other packages
  • luz_metric_binary_accuracy: Binary accuracy
  • luz_metric: Creates a new luz metric
  • set_opt_hparams: Set optimizer hyper-parameters
  • luz_metric_binary_auroc: Computes the area under the ROC curve
  • luz_metric_binary_accuracy_with_logits: Binary accuracy with logits
  • %>%: Pipe operator
  • predict.luz_module_fitted: Create predictions for a fitted model
  • luz_metric_rmse: Root mean squared error
  • luz_metric_multiclass_auroc: Computes the multi-class AUROC
  • setup: Sets up an nn_module for use with luz
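Several of these functions combine naturally. The sketch below, again on hypothetical random data with an nn_linear model, defines a custom callback with luz_callback() (print_callback and its message argument are illustrative names, not part of luz), adds luz_callback_early_stopping(), and round-trips the fitted model through luz_save() and luz_load(). Inside callback methods, ctx gives access to the training context, such as the current epoch.

```r
library(torch)
library(luz)

torch_manual_seed(1)

# Hypothetical toy data, reused for training and validation
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)
dl <- dataloader(tensor_dataset(x, y), batch_size = 20)

# Hypothetical custom callback: prints a message at the end of every epoch
print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_epoch_end = function() {
    cat("Epoch", ctx$epoch, self$message, "\n")
  }
)

fitted <- nn_linear %>%
  setup(loss = nn_mse_loss(), optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    dl,
    epochs = 5,
    valid_data = dl,
    callbacks = list(
      print_callback(message = "done"),
      # Stop if the validation loss fails to improve for 2 epochs
      luz_callback_early_stopping(monitor = "valid_loss", patience = 2)
    ),
    verbose = FALSE
  )

# Save the fitted object to a temporary file and load it back
path <- tempfile(fileext = ".luz")
luz_save(fitted, path)
reloaded <- luz_load(path)
```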