LBBNN (version 0.1.2)

train_LBBNN: Train an instance of LBBNN_Net.

Description

For each epoch, iterates over every mini-batch in `train_dl`, computes the loss, and uses back-propagation to update the network parameters.
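The epoch / mini-batch loop described above can be sketched in base R on a toy problem. This is not the LBBNN implementation. It is a minimal sketch that swaps the Bayesian network for an ordinary linear model, uses mean squared error as the loss, and replaces Adam with a plain gradient-descent step so that the loop structure stays visible. The function `train_sketch` and its arguments are hypothetical.

```r
# Hypothetical sketch of an epoch / mini-batch training loop.
# NOT the LBBNN implementation: linear model, MSE loss, plain gradient descent.
train_sketch <- function(X, y, epochs, lr, batch_size) {
  w <- rep(0, ncol(X))                      # model parameters
  n <- nrow(X)
  epoch_loss <- numeric(epochs)             # average mini-batch loss per epoch
  for (epoch in seq_len(epochs)) {
    # one epoch = one pass over all mini-batches (kept in order, no shuffling)
    batches <- split(seq_len(n), ceiling(seq_len(n) / batch_size))
    batch_losses <- numeric(length(batches))
    for (i in seq_along(batches)) {
      idx <- batches[[i]]
      Xb <- X[idx, , drop = FALSE]
      resid <- drop(Xb %*% w) - y[idx]
      batch_losses[i] <- mean(resid^2)                      # loss on this mini-batch
      grad <- 2 * drop(crossprod(Xb, resid)) / length(idx)  # gradient of the MSE w.r.t. w
      w <- w - lr * grad                                    # parameter update
    }
    epoch_loss[epoch] <- mean(batch_losses)
  }
  list(weights = w, loss = epoch_loss)
}

# Noiseless toy data with true coefficients c(1.5, -0.5)
set.seed(1)
X <- matrix(rnorm(200), ncol = 2)
y <- drop(X %*% c(1.5, -0.5))
fit <- train_sketch(X, y, epochs = 50, lr = 0.05, batch_size = 10)
```

Like the `loss` element returned by `train_LBBNN`, `fit$loss` holds one averaged loss value per epoch.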

Usage

train_LBBNN(
  epochs,
  LBBNN,
  lr,
  train_dl,
  device = "cpu",
  scheduler = NULL,
  sch_step_size = NULL
)

Value

A list, returned invisibly, containing the loss, the accuracy (classification only), and the network density for each training epoch. For comparison's sake, the density is shown both with and without active paths. The list has the following elements:

accs

Vector of accuracy per epoch (classification only).

loss

Vector of average loss per epoch.

density

Vector of network densities per epoch.
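The per-epoch vectors in this list can be collected into a data frame for inspection or plotting. The helper below is hypothetical and not part of LBBNN. It assumes, as described above, that `loss` and `density` are numeric vectors with one entry per epoch and that `accs` is present only for classification. If `density` is stored in another shape (for example, both with and without active paths), the helper would need adapting. The toy input contains placeholder values, not real training output.

```r
# Hypothetical helper, not part of LBBNN: tabulate the per-epoch training history.
# Assumes res$loss and res$density are numeric vectors of length `epochs`,
# and that res$accs is NULL for regression problems.
history_to_df <- function(res) {
  df <- data.frame(epoch = seq_along(res$loss),
                   loss = res$loss,
                   density = res$density)
  if (!is.null(res$accs)) df$accuracy <- res$accs   # classification only
  df
}

# Placeholder input with the assumed shape (values are illustrative only)
toy <- list(loss = c(2.0, 1.2, 0.9), density = c(1.0, 0.8, 0.6))
hist_df <- history_to_df(toy)
print(hist_df)
```

With real output, `history_to_df(output)` followed by `plot(loss ~ epoch, data = hist_df, type = "l")` would plot the loss curve.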

Arguments

epochs

Integer, the total number of epochs to train for, where one epoch is one pass through the entire training dataset (all mini-batches).
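Because one epoch visits every mini-batch once, the number of parameter updates per epoch follows directly from the dataset size and the dataloader's batch size. Here is a minimal base R sketch. It assumes the dataloader keeps a final partial batch (`drop_last = FALSE`, which is the `torch::dataloader` default). The function name is hypothetical.

```r
# Hypothetical helper: optimizer steps implied by an epoch count.
# Assumes a final partial mini-batch is kept (drop_last = FALSE).
steps_for <- function(n_obs, batch_size, epochs) {
  batches_per_epoch <- ceiling(n_obs / batch_size)
  c(batches_per_epoch = batches_per_epoch,
    total_steps = epochs * batches_per_epoch)
}

steps_for(n_obs = 1000, batch_size = 64, epochs = 10)  # 16 batches per epoch, 160 steps
```

For the example at the end of this page (3 observations, `batch_size = 3`, `epochs = 1`), this gives a single update.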

LBBNN

An instance of LBBNN_Net, to be trained.

lr

Numeric, the learning rate for the Adam optimizer.
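For reference, the standard Adam update that `lr` scales can be written as a single step in base R. This sketch implements the textbook rule with the commonly used defaults (beta1 = 0.9, beta2 = 0.999, eps = 1e-8). These defaults are assumptions here, because this page does not state the hyperparameters that LBBNN passes to its optimizer. The function `adam_step` is hypothetical.

```r
# Hypothetical sketch of one textbook Adam step (not LBBNN code).
# theta: parameters, grad: gradient at theta, state: list(m, v, t).
adam_step <- function(theta, grad, state, lr,
                      beta1 = 0.9, beta2 = 0.999, eps = 1e-8) {
  t <- state$t + 1
  m <- beta1 * state$m + (1 - beta1) * grad      # first-moment (mean) estimate
  v <- beta2 * state$v + (1 - beta2) * grad^2    # second-moment estimate
  m_hat <- m / (1 - beta1^t)                     # bias corrections for zero init
  v_hat <- v / (1 - beta2^t)
  theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
  list(theta = theta, state = list(m = m, v = v, t = t))
}

# Minimize f(theta) = sum(theta^2), whose gradient is 2 * theta
theta <- c(1, -2)
state <- list(m = 0, v = 0, t = 0)
for (i in 1:500) {
  out <- adam_step(theta, 2 * theta, state, lr = 0.01)
  theta <- out$theta
  state <- out$state
}
```

Early in training, each Adam step moves a parameter by roughly `lr` regardless of the gradient's scale. This makes the learning rate a fairly direct control over step size.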

train_dl

A torch::dataloader wrapping a tensor dataset of features and targets.

device

The device to train on. Defaults to 'cpu'; 'gpu' and 'mps' are also accepted.
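A common pattern is to choose the device from the hardware that is available before calling `train_LBBNN`. The helper below is hypothetical. It only maps availability flags to the strings this argument is documented to accept, preferring 'gpu', then 'mps', then 'cpu'.

```r
# Hypothetical helper: choose a value for the `device` argument.
# Prefers 'gpu', then 'mps', and falls back to 'cpu'.
pick_device <- function(cuda_available, mps_available) {
  if (isTRUE(cuda_available)) {
    "gpu"
  } else if (isTRUE(mps_available)) {
    "mps"
  } else {
    "cpu"
  }
}

pick_device(cuda_available = FALSE, mps_available = FALSE)  # "cpu"
```

In practice, the flags could come from `torch::cuda_is_available()` and `torch::backends_mps_is_available()`, assuming an installed torch version that provides both.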

scheduler

A torch learning-rate scheduler, used to decay the learning rate for better convergence. Currently only 'step' is supported.

sch_step_size

The step size when using torch::lr_step, i.e. how often the learning rate is decayed. For example, 1000 means the learning rate is decayed every 1000 epochs.
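Step decay multiplies the learning rate by a constant factor every `sch_step_size` scheduler steps. The sketch below computes the resulting schedule in base R. It assumes a decay factor of 0.1 (the default `gamma` of `torch::lr_step`) and that the scheduler is stepped once per epoch, as the description of `sch_step_size` implies. This page does not state the decay factor that LBBNN actually uses, and the function name is hypothetical.

```r
# Hypothetical sketch of step decay: lr_k = lr0 * gamma^floor(k / step_size),
# where k is the number of completed epochs (scheduler steps).
# gamma = 0.1 mirrors torch::lr_step's default; LBBNN's value is an assumption.
step_lr <- function(lr0, epoch, step_size, gamma = 0.1) {
  lr0 * gamma^(epoch %/% step_size)
}

# With lr = 0.01 and sch_step_size = 1000: 0.01, 0.01, 0.001, 0.0001
step_lr(0.01, epoch = c(0, 999, 1000, 2500), step_size = 1000)
```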

Examples

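# Toy regression data: 3 observations, 2 features
x <- torch::torch_randn(3, 2)
b <- torch::torch_rand(2)
y <- torch::torch_matmul(x, b)

# Wrap features and targets in a dataloader (a single mini-batch of 3)
train_data <- torch::tensor_dataset(x, y)
train_loader <- torch::dataloader(train_data, batch_size = 3, shuffle = FALSE)

# Layer sizes: 2 inputs, 1 hidden unit, 1 output
problem <- 'regression'
sizes <- c(2, 1, 1)
inclusion_priors <- c(0.9, 0.2)   # one value per weight layer
inclusion_inits <- matrix(rep(c(-10, 10), 2), nrow = 2, ncol = 2)
stds <- c(1.0, 1.0)               # one value per weight layer
model <- LBBNN_Net(problem, sizes, inclusion_priors, stds, inclusion_inits, flow = FALSE)

# Train for a single epoch with learning rate 0.01
output <- train_LBBNN(epochs = 1, LBBNN = model, lr = 0.01, train_dl = train_loader)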