Variational losses measure the divergence between an unnormalized target
distribution p (provided via target_log_prob_fn) and a surrogate
distribution q (provided as surrogate_posterior). When the
target distribution is an unnormalized posterior from conditioning a model on
data, minimizing the loss with respect to the parameters of
surrogate_posterior performs approximate posterior inference.
Usage

```r
vi_monte_carlo_variational_loss(
  target_log_prob_fn,
  surrogate_posterior,
  sample_size = 1,
  discrepancy_fn = vi_kl_reverse,
  use_reparametrization = NULL,
  seed = NULL,
  name = NULL
)
```

Arguments

target_log_prob_fn: function that takes a set of Tensor arguments and returns a Tensor log-density. Given q_sample <- surrogate_posterior$sample(sample_size), this will be called (in Python) as target_log_prob_fn(*q_sample) if q_sample is a list or a tuple, target_log_prob_fn(**q_sample) if q_sample is a dictionary, or target_log_prob_fn(q_sample) if q_sample is a Tensor. It should support batched evaluation, i.e., it should return a result of shape [sample_size].
surrogate_posterior: A tfp$distributions$Distribution instance defining a variational posterior (could be a tfp$distributions$JointDistribution). Crucially, the distribution's log_prob and (if reparameterized) sample methods must directly invoke all ops that generate gradients to the underlying variables. One way to ensure this is to use tfp$util$DeferredTensor to represent any parameters defined as transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation.
sample_size: integer number of Monte Carlo samples to use in estimating the variational divergence. Larger values may stabilize the optimization, but at higher cost per step in time and memory. Default value: 1.
discrepancy_fn: function representing a Csiszar f function in log-space. That is, discrepancy_fn(log(u)) = f(u), where f is convex in u. Default value: vi_kl_reverse.
use_reparametrization: logical. When NULL (the default), automatically set to surrogate_posterior$reparameterization_type == tfp$distributions$FULLY_REPARAMETERIZED. When TRUE, uses the standard Monte Carlo average. When FALSE, uses the score-gradient trick (see "Tricks: Reparameterization and Score-Gradient" below); in that case, consider using vi_csiszar_vimco.
seed: integer seed for surrogate_posterior$sample.

name: name prefixed to Ops created by this function.
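The following base-R sketch illustrates the trade-off described for sample_size. It uses no TensorFlow. The target Normal(0.5, 1), the surrogate Normal(0, 1), and the helper estimate_loss are assumptions chosen for illustration. The sketch repeats a K-sample estimate of the default reverse-KL loss many times and compares the spread of the estimates for K = 1 and K = 100.

```r
# Illustrative sketch in base R (not tfprobability): the spread of a K-sample
# Monte Carlo estimate of E_q[-log(p(z) / q(z))] shrinks roughly like 1 / sqrt(K).
set.seed(42)

log_p <- function(z) dnorm(z, mean = 0.5, sd = 1, log = TRUE)  # assumed target
log_q <- function(z) dnorm(z, mean = 0, sd = 1, log = TRUE)    # assumed surrogate

# Hypothetical helper: one K-sample estimate of the reverse-KL loss.
estimate_loss <- function(K) {
  z <- rnorm(K, mean = 0, sd = 1)   # K draws from q
  logu <- log_p(z) - log_q(z)
  mean(-logu)                       # average of discrepancy_fn(logu) = -logu
}

# Repeat each estimator 2000 times and compare standard deviations.
sd_k1 <- sd(replicate(2000, estimate_loss(1)))
sd_k100 <- sd(replicate(2000, estimate_loss(100)))
print(c(sd_k1 = sd_k1, sd_k100 = sd_k100))
```

For this pair of Normals, -logu = 0.125 - 0.5z. The single-sample standard deviation is therefore 0.5, and averaging 100 samples should shrink it about tenfold. The cost is 100 times as many density evaluations per step.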
Value

monte_carlo_variational_loss: float-like Tensor holding the Monte Carlo approximation of the Csiszar f-Divergence.
Details

This function defines divergences of the form E_q[discrepancy_fn(log p(z) - log q(z))], sometimes known as f-divergences.

In the special case discrepancy_fn(logu) == -logu (the default vi_kl_reverse), this is the reverse Kullback-Leibler divergence KL[q||p]. Its negation, applied to an unnormalized p, is the widely used evidence lower bound (ELBO). Other cases of interest available under tfp$vi include the forward KL[p||q] (given by vi_kl_forward(logu) == exp(logu) * logu), the total variation distance, Amari alpha-divergences, and more.
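A minimal base-R sketch can make the quantity E_q[discrepancy_fn(log p(z) - log q(z))] concrete. It assumes a one-dimensional Normal surrogate and uses hypothetical helpers (mc_variational_loss, kl_reverse_fn, kl_forward_fn). These helpers mirror the tfprobability functions but are not them. The sketch computes only the loss value, not its gradient.

```r
# Minimal base-R sketch of E_q[discrepancy_fn(log p(z) - log q(z))] with a
# Normal(q_mean, q_sd) surrogate. All names are hypothetical, not tfprobability APIs.
kl_reverse_fn <- function(logu) -logu             # f(u) = -log(u)
kl_forward_fn <- function(logu) exp(logu) * logu  # f(u) = u * log(u)

mc_variational_loss <- function(target_log_prob_fn, q_mean, q_sd,
                                sample_size = 1, discrepancy_fn = kl_reverse_fn) {
  q_sample <- rnorm(sample_size, mean = q_mean, sd = q_sd)
  # target_log_prob_fn must be vectorized: one log-density per draw.
  logu <- target_log_prob_fn(q_sample) -
    dnorm(q_sample, mean = q_mean, sd = q_sd, log = TRUE)
  mean(discrepancy_fn(logu))
}

# Usage: an unnormalized target equal to 3 times the Normal(0, 1) density.
set.seed(1)
unnormalized_log_p <- function(z) log(3) + dnorm(z, log = TRUE)
mc_variational_loss(unnormalized_log_p, q_mean = 0, q_sd = 1, sample_size = 10)
mc_variational_loss(unnormalized_log_p, q_mean = 0, q_sd = 1, sample_size = 10,
                    discrepancy_fn = kl_forward_fn)
```

Here the target is exactly 3 times the surrogate density, so every draw gives logu = log(3) up to rounding. The reverse-KL loss is then -log(3): the negative ELBO equals minus the log normalizing constant. The forward-KL discrepancy gives 3 * log(3).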
Csiszar f-divergences
A Csiszar function f is a convex function from R^+ (the positive reals)
to R. The Csiszar f-Divergence is given by:
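$$
D_f[p(X), q(X)] := \mathbb{E}_{q(X)}\left[ f\left( \frac{p(X)}{q(X)} \right) \right]
\approx \frac{1}{m} \sum_{j=1}^{m} f\left( \frac{p(x_j)}{q(x_j)} \right),
\qquad x_j \overset{\text{iid}}{\sim} q(X).
$$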
For example, f(u) = -log(u) recovers KL[q||p], while f(u) = u log(u) recovers the forward KL[p||q]. These and other functions are available in tfp$vi; in R they are provided, in log-space form, as the vi_* functions.
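The following base-R sketch serves as a sanity check of the Monte Carlo estimator and of the two choices of f. It estimates D_f for q = Normal(0, 1) and p = Normal(0.5, 1) and compares the result with the closed-form KL divergence, which for two unit-variance Normals is half the squared difference of the means. The densities, seed, and sample size m are arbitrary assumptions.

```r
# Sketch: Monte Carlo estimate of D_f[p, q] = E_q[f(p(X) / q(X))] for two
# Normal densities, compared with the closed-form KL divergences.
set.seed(123)
m <- 1e5
x <- rnorm(m, mean = 0, sd = 1)                # x_j ~iid q = Normal(0, 1)
u <- dnorm(x, mean = 0.5, sd = 1) / dnorm(x)   # p(x_j) / q(x_j), p = Normal(0.5, 1)

f_reverse <- function(u) -log(u)     # recovers KL[q || p]
f_forward <- function(u) u * log(u)  # recovers KL[p || q]

d_reverse <- mean(f_reverse(u))
d_forward <- mean(f_forward(u))

# For unit-variance Normals whose means differ by 0.5, both KLs equal 0.5^2 / 2.
exact_kl <- 0.5^2 / 2
print(c(d_reverse = d_reverse, d_forward = d_forward, exact = exact_kl))
```

Both estimates should land close to 0.125. The forward-KL estimate is somewhat noisier because u enters it multiplicatively.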
Tricks: Reparameterization and Score-Gradient
When q is "reparameterized", i.e., a diffeomorphic transformation of a
parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)),
we can swap gradient and expectation, i.e.,
$$
\nabla S_n = \nabla\left[ \frac{1}{n} \sum_{i=1}^{n} s_i \right] = \frac{1}{n} \sum_{i=1}^{n} \nabla s_i,
\qquad s_i = f(x_i), \quad x_i \overset{\text{iid}}{\sim} q(X).
$$
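The swap of gradient and average can be checked numerically without any autodiff. The following base-R sketch assumes q = Normal(m, s) and the test function f(x) = x^2, for which E_q[f(X)] = m^2 + s^2 and the exact gradient with respect to m is 2m. The derivative f'(x) = 2x and dx/dm = 1 are worked out by hand rather than by TensorFlow.

```r
# Sketch (base R, gradient worked out by hand): pathwise gradient of E_q[f(X)]
# with respect to the location m of q = Normal(m, s), writing X = s * eps + m.
# Assumed example: f(x) = x^2, so E_q[f(X)] = m^2 + s^2 and d/dm = 2 * m.
set.seed(7)
m <- 1.5
s <- 0.8
n <- 1e5

eps <- rnorm(n)          # draws from the parameterless Normal(0, 1)
x <- s * eps + m         # diffeomorphic transform of the noise
dx_dm <- 1               # derivative of the transform with respect to m
f_prime <- function(x) 2 * x

# Swapping gradient and average: d/dm Avg{f(x_i)} = Avg{f'(x_i) * dx_i/dm}.
grad_reparam <- mean(f_prime(x) * dx_dm)
print(c(estimate = grad_reparam, exact = 2 * m))
```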
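However, if q is not reparameterized, TensorFlow's gradient will be incorrect, since the chain rule stops at samples of unreparameterized distributions. In this circumstance, using the Score-Gradient trick results in an unbiased gradient, i.e.,
$$
\begin{aligned}
\nabla \mathbb{E}_q[f(X)] &= \nabla \int q(x) f(x)\, dx = \int \nabla\left[ q(x) f(x) \right] dx \\
&= \int \left[ q'(x) f(x) + q(x) f'(x) \right] dx \\
&= \int q(x) \left[ \frac{q'(x)}{q(x)} f(x) + f'(x) \right] dx \\
&= \int q(x)\, \nabla\left[ \frac{f(x)\, q(x)}{\mathrm{stop\_grad}[q(x)]} \right] dx \\
&= \mathbb{E}_q\left[ \nabla\left[ \frac{f(x)\, q(x)}{\mathrm{stop\_grad}[q(x)]} \right] \right],
\end{aligned}
$$
where primes denote gradients with respect to the parameters of q (not with respect to x).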
When q$reparameterization_type == tfp$distributions$FULLY_REPARAMETERIZED, it is usually preferable to set use_reparametrization = TRUE.
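The following base-R sketch shows both the score-gradient estimator and why reparameterization is usually preferred. It reuses the assumed setting q = Normal(m, s) with f(x) = x^2. For the Normal, d/dm log q(x) = (x - m) / s^2. Because this f does not depend on m, the f'(x) term of the derivation vanishes, and the gradient of the surrogate f(x) q(x) / stop_grad[q(x)] reduces to f(x) (x - m) / s^2. The sketch computes this by hand rather than through TensorFlow's stop_gradient. It then compares the per-sample spread of the two estimators on the same draws.

```r
# Sketch (base R): score-gradient vs. reparameterization estimates of
# d/dm E_q[f(X)] for q = Normal(m, s) and the assumed f(x) = x^2 (exact: 2 * m).
set.seed(7)
m <- 1.5
s <- 0.8
n <- 1e5

x <- rnorm(n, mean = m, sd = s)  # samples treated as constants (no pathwise gradient)
f <- function(x) x^2

score <- (x - m) / s^2                # d/dm log q(x) for the Normal
per_sample_score <- f(x) * score      # f does not depend on m, so no f'(x) term
per_sample_reparam <- 2 * x           # f'(x) * dx/dm, valid because x = s * eps + m

print(c(score_grad = mean(per_sample_score),
        reparam_grad = mean(per_sample_reparam),
        exact = 2 * m))
# Per-sample spread: here the score-gradient terms are much noisier.
print(c(sd_score = sd(per_sample_score), sd_reparam = sd(per_sample_reparam)))
```

Both estimators are unbiased, but in this example the score-gradient terms have a standard deviation roughly four times larger. That larger variance is the practical reason to prefer reparameterization whenever q supports it.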
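Example Application: The Csiszar f-Divergence is a useful framework for variational inference. To see this, observe that
$$
\begin{aligned}
f(p(x)) &= f\left( \mathbb{E}_{q(Z \mid x)}\left[ \frac{p(x, Z)}{q(Z \mid x)} \right] \right) \\
&\le \mathbb{E}_{q(Z \mid x)}\left[ f\left( \frac{p(x, Z)}{q(Z \mid x)} \right) \right] \\
&:= D_f[p(x, Z), q(Z \mid x)].
\end{aligned}
$$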
The inequality follows from Jensen's inequality, since f is convex. Equivalently, it follows from the fact that the "perspective" of f, i.e., (s, t) |-> t f(s / t), is convex in (s, t) when s/t is in domain(f) and t > 0. Since this framework includes the popular Evidence Lower BOund (ELBO) as a special case, namely f(u) = -log(u), we call it "Evidence Divergence Bound Optimization" (EDBO).
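The bound can be checked numerically for f(u) = -log(u), where it reads -log p(x) <= -ELBO. The following base-R sketch assumes the toy model Z ~ Normal(0, 1), x | Z ~ Normal(Z, 1). This model has marginal x ~ Normal(0, sqrt(2)) and posterior Z | x ~ Normal(x / 2, sqrt(1/2)). The observed value x_obs and the helper edbo_bound are hypothetical. The sketch estimates the bound for two surrogates: the exact posterior, which should make it tight, and the prior, which should leave a gap.

```r
# Sketch (base R): check -log p(x) <= E_q[-log(p(x, Z) / q(Z | x))] for the
# assumed model Z ~ Normal(0, 1), x | Z ~ Normal(Z, 1).
set.seed(2024)
x_obs <- 1.3
log_joint <- function(z) dnorm(z, 0, 1, log = TRUE) + dnorm(x_obs, z, 1, log = TRUE)
neg_log_evidence <- -dnorm(x_obs, 0, sqrt(2), log = TRUE)  # -log p(x), exact

# Hypothetical helper: Monte Carlo D_f with f(u) = -log(u), q = Normal(q_mean, q_sd).
edbo_bound <- function(q_mean, q_sd, n = 1e5) {
  z <- rnorm(n, q_mean, q_sd)
  mean(-(log_joint(z) - dnorm(z, q_mean, q_sd, log = TRUE)))
}

bound_exact_q <- edbo_bound(x_obs / 2, sqrt(1 / 2))  # q is the true posterior
bound_poor_q <- edbo_bound(0, 1)                     # q is the prior
print(c(neg_log_evidence = neg_log_evidence,
        exact_q = bound_exact_q, poor_q = bound_poor_q))
```

With the exact posterior, p(x, Z) / q(Z | x) equals p(x) for every draw, so the bound is tight up to rounding. With the prior as surrogate, the gap equals KL[q(Z) || p(Z | x)], about 0.58 here.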
References

Ali, Syed Mumtaz, and Samuel D. Silvey. "A general class of coefficients of divergence of one distribution from another." Journal of the Royal Statistical Society: Series B (Methodological) 28.1 (1966): 131-142.
Other vi-functions:
vi_amari_alpha(),
vi_arithmetic_geometric(),
vi_chi_square(),
vi_csiszar_vimco(),
vi_dual_csiszar_function(),
vi_fit_surrogate_posterior(),
vi_jeffreys(),
vi_jensen_shannon(),
vi_kl_forward(),
vi_kl_reverse(),
vi_log1p_abs(),
vi_modified_gan(),
vi_pearson(),
vi_squared_hellinger(),
vi_symmetrized_csiszar_function()