This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions that slowly interpolates between an initial "proposal" distribution,

  exp(proposal_log_prob_fn(x) - proposal_log_normalizer)

and the target distribution,

  exp(target_log_prob_fn(x) - target_log_normalizer)

accumulating importance weights along the way. The product of these importance weights gives an unbiased estimate of the ratio of the target distribution's normalizing constant to that of the initial distribution:

  E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)
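The procedure described above can be illustrated with a minimal, library-free Python sketch. It is not the tfprobability implementation: it assumes a geometric interpolation with a linear schedule, swaps in a random-walk Metropolis update for Hamiltonian Monte Carlo, and all names (`proposal_log_prob`, `target_log_prob`, `metropolis_step`, `ais_log_weight`, `estimate_log_ratio`) are hypothetical. The proposal is a normalized N(0, 1) and the target an unnormalized N(1, 0.5^2), so the true ratio of normalizing constants is known in closed form.

```python
import math
import random


def proposal_log_prob(x):
    # Normalized N(0, 1), so its log normalizer is 0.
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def target_log_prob(x):
    # Unnormalized N(1, 0.5^2); its normalizer is 0.5 * sqrt(2 * pi).
    return -0.5 * ((x - 1.0) / 0.5) ** 2


def metropolis_step(x, log_prob, rng, scale=0.5):
    # One random-walk Metropolis update leaving exp(log_prob) invariant.
    # (Assumption: a stand-in for the HMC kernel mentioned in the text.)
    y = x + rng.gauss(0.0, scale)
    accept_prob = math.exp(min(0.0, log_prob(y) - log_prob(x)))
    return y if rng.random() < accept_prob else x


def ais_log_weight(num_steps, rng):
    x = rng.gauss(0.0, 1.0)  # exact draw from the proposal
    log_w = 0.0
    for k in range(1, num_steps + 1):
        beta_prev, beta = (k - 1) / num_steps, k / num_steps
        # Log-weight increment: log f_k(x) - log f_{k-1}(x) under the
        # assumed geometric interpolation.
        log_w += (beta - beta_prev) * (target_log_prob(x) - proposal_log_prob(x))

        def log_f(z, b=beta):
            return b * target_log_prob(z) + (1.0 - b) * proposal_log_prob(z)

        # Move the state with a kernel targeting the k-th distribution.
        x = metropolis_step(x, log_f, rng)
    return log_w


def estimate_log_ratio(num_chains=2000, num_steps=50, seed=0):
    rng = random.Random(seed)
    log_ws = [ais_log_weight(num_steps, rng) for _ in range(num_chains)]
    m = max(log_ws)
    # Numerically stable log of the mean of exp(log_w) over chains.
    return m + math.log(sum(math.exp(w - m) for w in log_ws) / num_chains)


true_log_ratio = math.log(0.5 * math.sqrt(2.0 * math.pi))
estimate = estimate_log_ratio()
print(estimate, true_log_ratio)
```

Averaging exp of the log weights over many independent chains is what makes the estimate of the ratio unbiased; a single chain's weight can be far from the true value.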
mcmc_sample_annealed_importance_chain(
num_steps,
proposal_log_prob_fn,
target_log_prob_fn,
current_state,
make_kernel_fn,
parallel_iterations = 10,
name = NULL
)
list of next_state (Tensor or Python list of Tensors representing the state(s) of the Markov chain(s) at the final iteration; has the same shape as the input current_state), ais_weights (Tensor with the estimated weight(s); has shape matching target_log_prob_fn(current_state)), and kernel_results (collections.namedtuple of internal calculations used to advance the chain).
num_steps: Integer number of Markov chain updates to run. More iterations cost more, but give smoother annealing between the proposal and target distributions, which in turn means exponentially lower variance for the normalizing constant estimator.
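To make the role of the number of steps concrete, the accumulated log weight can be written out under one common choice of annealing path. This is a sketch under assumptions, not a statement of what the function computes internally: it assumes a geometric interpolation with a linear schedule beta_k = k / K, where K stands for num_steps and x_0, ..., x_{K-1} are the successive chain states (all symbols here are introduced for illustration).

```latex
\begin{aligned}
\log f_k(x) &= \beta_k \,\mathrm{target\_log\_prob\_fn}(x)
             + (1 - \beta_k)\,\mathrm{proposal\_log\_prob\_fn}(x),
             \qquad \beta_k = \frac{k}{K}, \\
\mathrm{ais\_weights} &= \sum_{k=1}^{K}
    \bigl[\log f_k(x_{k-1}) - \log f_{k-1}(x_{k-1})\bigr] \\
  &= \sum_{k=1}^{K} (\beta_k - \beta_{k-1})
    \bigl[\mathrm{target\_log\_prob\_fn}(x_{k-1})
        - \mathrm{proposal\_log\_prob\_fn}(x_{k-1})\bigr].
\end{aligned}
```

With a larger K, each increment beta_k - beta_{k-1} = 1/K is smaller, so consecutive distributions overlap more and each term contributes less variance to the sum.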
proposal_log_prob_fn: function that returns the log density of the initial distribution.
target_log_prob_fn: function which takes an argument like current_state and returns its (possibly unnormalized) log-density under the target distribution.
current_state: Tensor or list of Tensors representing the current state(s) of the Markov chain(s). The first r dimensions index independent chains, where r = tf$rank(target_log_prob_fn(current_state)).
make_kernel_fn: function which returns a TransitionKernel-like object. Must take one argument representing the TransitionKernel's target_log_prob_fn. The target_log_prob_fn argument represents the TransitionKernel's target log distribution. Note: sample_annealed_importance_chain creates a new target_log_prob_fn which is an interpolation between the supplied target_log_prob_fn and proposal_log_prob_fn; it is this interpolated function which is used as an argument to make_kernel_fn.
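The contract described for make_kernel_fn (one argument, the log-density the kernel should target, which is an interpolated function rather than the user's original one) can be sketched in plain Python. Everything here is hypothetical and for illustration only: `ToyKernel` is a stand-in for a TransitionKernel-like object, `make_interpolated_log_prob_fn` is an assumed helper, and the convex combination with a weight `beta` is an assumption about the form of the interpolation.

```python
def make_interpolated_log_prob_fn(proposal_log_prob_fn, target_log_prob_fn, beta):
    # Hypothetical helper: convex combination of the two log-densities.
    # beta = 0 recovers the proposal, beta = 1 recovers the target.
    def interpolated(x):
        return beta * target_log_prob_fn(x) + (1.0 - beta) * proposal_log_prob_fn(x)
    return interpolated


class ToyKernel:
    # Hypothetical stand-in for a TransitionKernel-like object; it only
    # stores the log-density it was asked to target.
    def __init__(self, target_log_prob_fn):
        self.target_log_prob_fn = target_log_prob_fn


def make_kernel_fn(target_log_prob_fn):
    # Takes exactly one argument: the log-density the kernel should target.
    return ToyKernel(target_log_prob_fn)


def proposal(x):
    return -0.5 * x * x


def target(x):
    return -abs(x - 2.0)


# The kernel receives the interpolated function, not `target` itself.
kernel = make_kernel_fn(make_interpolated_log_prob_fn(proposal, target, beta=0.25))
value = kernel.target_log_prob_fn(1.0)
```

The point of passing a factory rather than a ready-made kernel is that a fresh kernel can be built for each interpolated log-density along the annealing path.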
parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See tf$while_loop for more details.
name: string prefixed to Ops created by this function. Default value: NULL (i.e., "sample_annealed_importance_chain").
Note: When running in graph mode, proposal_log_prob_fn and target_log_prob_fn are called exactly three times (although this may be reduced to two times in the future).
For an example of how to use this function, see mcmc_sample_chain().
Other mcmc_functions: mcmc_effective_sample_size(), mcmc_potential_scale_reduction(), mcmc_sample_chain(), mcmc_sample_halton_sequence()