
shrinkGPR (version 1.0.0)

sylvester: Sylvester Normalizing Flow

Description

The sylvester function implements Sylvester normalizing flows as described by van den Berg et al. (2018) in "Sylvester Normalizing Flows for Variational Inference." This flow applies a sequence of invertible transformations to map a simple base distribution to a more complex target distribution, allowing for flexible posterior approximations in Gaussian process regression models.

Usage

sylvester(d, n_householder)

Value

An nn_module object representing the Sylvester normalizing flow. The module has the following key components:

  • forward(zk): The forward pass takes a batch zk (one sample per row) and returns a list with elements zk, the transformed variable z, and log_diag_j, the log-determinant contribution of the Jacobian.

  • Internal parameters include the upper triangular matrices R1 and R2, their learned diagonal elements, and the Householder reflection vectors used to construct the orthogonal matrix Q.

Arguments

d

An integer specifying the latent dimensionality of the input space.

n_householder

An optional integer specifying the number of Householder reflections used to orthogonalize the transformation. Defaults to d - 1.
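To illustrate how Householder reflections produce an orthogonal matrix, here is a minimal base-R sketch. The helper householder_Q and the matrix V of reflection vectors are hypothetical and not part of shrinkGPR; the package's internal parameterization may differ. Each reflection H_j = I - 2 v_j v_j^T / ||v_j||^2 is orthogonal, so their product is orthogonal as well.

```r
# Hypothetical helper (not the shrinkGPR internals): build a d x d orthogonal
# matrix Q as the product of Householder reflections, one per row of V.
householder_Q <- function(V) {
  d <- ncol(V)
  Q <- diag(d)
  for (j in seq_len(nrow(V))) {
    v <- V[j, ]
    # Reflection across the hyperplane orthogonal to v
    H <- diag(d) - 2 * tcrossprod(v) / sum(v^2)
    Q <- Q %*% H
  }
  Q
}

set.seed(1)
d <- 5
n_householder <- d - 1   # mirrors the documented default
V <- matrix(rnorm(n_householder * d), nrow = n_householder)
Q <- householder_Q(V)

# Q^T Q should equal the identity up to floating-point error
err <- max(abs(crossprod(Q) - diag(d)))
print(err)
```

Parameterizing Q through reflection vectors keeps it exactly orthogonal during optimization without any explicit constraint.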

Details

The Sylvester flow uses two triangular matrices (R1 and R2) and Householder reflections to construct invertible transformations. The transformation is parameterized as follows: $$z = z_k + Q R_1 h(R_2 Q^T z_k + b),$$ where z_k is the input (zk in the code) and:

  • Q is an orthogonal matrix obtained via Householder reflections.

  • R1 and R2 are upper triangular matrices with learned diagonal elements.

  • h is a non-linear activation function (default: torch_tanh).

  • b is a learned bias vector.
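The transformation above can be sketched in base R for a batch stored one sample per row. This is an illustrative sketch, not the package's torch implementation: the function sylvester_forward is hypothetical, Q is taken from a QR decomposition (any orthogonal matrix satisfies the formula) instead of Householder reflections, and h defaults to tanh. For row vectors the map reads z^T = z_k^T + h(z_k^T Q R_2^T + b^T) R_1^T Q^T.

```r
# Hypothetical base-R sketch of z = z_k + Q R1 h(R2 Q^T z_k + b), applied
# row-wise to a batch Zk (one sample per row).
sylvester_forward <- function(Zk, Q, R1, R2, b, h = tanh) {
  A <- Zk %*% Q %*% t(R2)   # rows are (R2 Q^T z_k)^T
  A <- sweep(A, 2, b, "+")  # add the bias b to every row
  Zk + h(A) %*% t(R1) %*% t(Q)
}

set.seed(42)
d <- 3
Q  <- qr.Q(qr(matrix(rnorm(d * d), d)))          # an orthogonal matrix
R1 <- matrix(rnorm(d * d), d); R1[lower.tri(R1)] <- 0  # upper triangular
R2 <- matrix(rnorm(d * d), d); R2[lower.tri(R2)] <- 0  # upper triangular
b  <- rnorm(d)
Zk <- matrix(rnorm(10 * d), nrow = 10)           # batch of 10 samples

Z <- sylvester_forward(Zk, Q, R1, R2, b)

# Cross-check the first row against the column-vector form of the formula
z1 <- Zk[1, ] + Q %*% R1 %*% tanh(R2 %*% t(Q) %*% Zk[1, ] + b)
print(all.equal(as.numeric(z1), Z[1, ]))
```

The residual term z_k keeps the map close to the identity when R1 and R2 are small, which is what makes stacking several such layers practical.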

The log-determinant of the Jacobian of z with respect to z_k, which is needed to account for the change in density under the transformation, is given by: $$\log |\det J| = \sum_{i=1}^{d} \log \left| 1 + \mathrm{diag}_1[i] \, \mathrm{diag}_2[i] \, h'\big((R_2 Q^T z_k + b)_i\big) \right|,$$ where diag_1 and diag_2 are the learned diagonal elements of R1 and R2, respectively, and h' is the derivative of the activation function.
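The closed form follows because J = I + Q R_1 diag(h'(a)) R_2 Q^T with a = R_2 Q^T z_k + b, so det J = det(I + R_1 diag(h'(a)) R_2), and the product of upper triangular matrices inside is itself upper triangular. The base-R sketch below checks this numerically for a single sample; it assumes h = tanh and uses a QR-based orthogonal Q, both illustrative choices rather than the package's internals.

```r
# Hypothetical numerical check of the closed-form log-determinant (one sample).
set.seed(7)
d  <- 4
dh <- function(x) 1 - tanh(x)^2   # derivative of tanh

Q  <- qr.Q(qr(matrix(rnorm(d * d), d)))
R1 <- matrix(rnorm(d * d), d); R1[lower.tri(R1)] <- 0
R2 <- matrix(rnorm(d * d), d); R2[lower.tri(R2)] <- 0
b  <- rnorm(d)
zk <- rnorm(d)

a <- as.numeric(R2 %*% t(Q) %*% zk + b)   # pre-activation R2 Q^T z_k + b

# Closed form: sum_i log|1 + diag1[i] * diag2[i] * h'(a_i)|
logdet_closed <- sum(log(abs(1 + diag(R1) * diag(R2) * dh(a))))

# Brute force: log|det J| with J = I + Q R1 diag(h'(a)) R2 Q^T
# (diag(dh(a)) builds a diagonal matrix because length(a) > 1)
J <- diag(d) + Q %*% R1 %*% diag(dh(a)) %*% R2 %*% t(Q)
logdet_brute <- as.numeric(determinant(J, logarithm = TRUE)$modulus)

print(all.equal(logdet_closed, logdet_brute))
```

The closed form costs O(d) once the pre-activation is known, versus O(d^3) for a general determinant, which is the main reason for the triangular parameterization.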

References

van den Berg, R., Hasenclever, L., Tomczak, J. M., & Welling, M. (2018). "Sylvester Normalizing Flows for Variational Inference." Proceedings of the Thirty-Fourth Conference on Uncertainty in Artificial Intelligence (UAI 2018).

Examples

if (torch::torch_is_installed()) {
  # Example: Initialize a Sylvester flow
  d <- 5
  n_householder <- 4
  flow <- sylvester(d, n_householder)

  # Forward pass through the flow
  zk <- torch::torch_randn(10, d)  # Batch of 10 samples
  result <- flow(zk)

  print(result$zk)
  print(result$log_diag_j)
}
