
LBBNN (version 0.1.2)

validate_LBBNN: Validate a trained LBBNN model.

Description

Computes metrics on a validation dataset without computing gradients. Supports model averaging (recommended): when num_samples > 1, predictions are averaged over several samples from the variational posterior to improve them. Returns metrics for both the full model and the sparse model, which uses only the weights in active paths.
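A minimal base-R sketch of what averaging over num_samples posterior draws means for classification. It assumes each draw yields an n x K matrix of class probabilities and that the draws are averaged before picking the most probable class. The helper `average_and_score` is hypothetical and not part of the LBBNN API, and whether validate_LBBNN averages probabilities or logits is an assumption here.

```r
# Hypothetical helper (not part of LBBNN): average class-probability
# matrices from several posterior draws, then compute accuracy.
average_and_score <- function(prob_list, labels) {
  # prob_list: list of n x K probability matrices, one per posterior draw
  # labels: integer vector of true classes in 1..K
  avg_probs <- Reduce(`+`, prob_list) / length(prob_list)
  # Index of the most probable class in each row
  predicted <- max.col(avg_probs, ties.method = "first")
  mean(predicted == labels)
}

# Toy inputs: 2 posterior draws, 3 observations, 2 classes
p1 <- matrix(c(0.9, 0.1,
               0.4, 0.6,
               0.3, 0.7), ncol = 2, byrow = TRUE)
p2 <- matrix(c(0.8, 0.2,
               0.7, 0.3,
               0.2, 0.8), ncol = 2, byrow = TRUE)
labels <- c(1L, 1L, 2L)

average_and_score(list(p1), labels)      # single draw: 2/3
average_and_score(list(p1, p2), labels)  # averaged draws: 1
```

In this toy case the single draw p1 misclassifies the second observation, while the averaged prediction gets all three right; averaging is not guaranteed to help on every dataset.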

Usage

validate_LBBNN(LBBNN, num_samples, test_dl, device = "cpu")

Value

A list containing the following elements:

accuracy_full_model

Classification accuracy of the full (dense) model (if classification).

accuracy_sparse

Classification accuracy using only weights in active paths (if classification).

validation_error

Root mean squared error for the full model (if regression).

validation_error_sparse

Root mean squared error using only weights in active paths (if regression).

density

Proportion of weights with posterior inclusion probability > 0.5 in the whole network.

density_active_path

Proportion of weights with posterior inclusion probability > 0.5 after removing weights not in active paths.
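A minimal base-R sketch of how the regression error and the two density values can be read. It assumes the inclusion probabilities are available as a flat numeric vector, active-path membership as a logical mask of the same length, and that both densities divide by the total number of weights. The helpers `rmse`, `weight_density` and `active_path_density` are hypothetical and not part of the LBBNN API.

```r
# Hypothetical helpers (not part of LBBNN) illustrating the returned metrics.

# Root mean squared error between predictions and targets
rmse <- function(pred, y) sqrt(mean((pred - y)^2))

# Share of all weights whose inclusion probability exceeds 0.5
weight_density <- function(alpha) mean(alpha > 0.5)

# Share of all weights that are both included (> 0.5) and in an active path
# (assumption: denominator is the total number of weights)
active_path_density <- function(alpha, in_active_path) {
  mean(alpha > 0.5 & in_active_path)
}

# Toy inputs for five weights
alpha <- c(0.9, 0.2, 0.7, 0.6, 0.1)
in_active_path <- c(TRUE, FALSE, FALSE, TRUE, FALSE)

weight_density(alpha)                       # 0.6
active_path_density(alpha, in_active_path)  # 0.4
rmse(c(1, 2, 3), c(1, 2, 5))                # sqrt(4/3), about 1.155
```

The third weight is included (0.7 > 0.5) but lies outside every active path, so it counts toward density but not toward density_active_path.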

Arguments

LBBNN

An instance of a trained LBBNN_Net to be validated.

num_samples

integer, the number of samples from the variational posterior to be used for model averaging.

test_dl

An instance of torch::dataloader containing the validation data.

device

The device to perform validation on. Default is 'cpu'; other options include 'gpu' and 'mps'.