A TransformedDistribution models p(y) given a base distribution p(x) and a
deterministic, invertible, differentiable transform, Y = g(X). The transform is
typically an instance of the Bijector class and the base distribution is
typically an instance of the Distribution class.
tfd_transformed_distribution(
distribution,
bijector,
batch_shape = NULL,
event_shape = NULL,
kwargs_split_fn = NULL,
validate_args = FALSE,
parameters = NULL,
name = NULL
)
Returns a distribution instance.
distribution: The base distribution instance to transform. Typically an instance of Distribution.
bijector: The object responsible for calculating the transformation. Typically an instance of Bijector.
batch_shape: Integer vector Tensor which overrides distribution batch_shape; valid only if distribution.is_scalar_batch().
event_shape: Integer vector Tensor which overrides distribution event_shape; valid only if distribution.is_scalar_event().
kwargs_split_fn: Python callable which takes a kwargs dict and returns a tuple of kwargs dicts for the distribution and bijector parameters, respectively. Default value: _default_kwargs_split_fn (i.e., lambda kwargs: (kwargs.get('distribution_kwargs', {}), kwargs.get('bijector_kwargs', {}))).
validate_args: Logical. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs. Default value: FALSE.
parameters: Locals dict captured by subclass constructor, to be used for copy/slice re-instantiation operations.
name: The name for ops managed by the distribution. Default value: bijector.name + distribution.name.
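The default kwargs_split_fn is given above as a Python lambda. Purely to illustrate what it does, here is a base-R analogue; the function name `default_kwargs_split` is hypothetical and is not exported by tfprobability, which hands such a callable through to the underlying Python library.

```r
# Hypothetical base-R analogue of the default kwargs_split_fn (illustration only;
# not a tfprobability function).
default_kwargs_split <- function(kwargs) {
  # [[ ]] matches names exactly and returns NULL for a missing entry,
  # so a missing key falls back to an empty list, like kwargs.get(key, {}).
  distribution_kwargs <- kwargs[["distribution_kwargs"]]
  bijector_kwargs <- kwargs[["bijector_kwargs"]]
  if (is.null(distribution_kwargs)) distribution_kwargs <- list()
  if (is.null(bijector_kwargs)) bijector_kwargs <- list()
  # First element goes to the distribution, second to the bijector.
  list(distribution_kwargs, bijector_kwargs)
}

# Example: only bijector kwargs are supplied.
parts <- default_kwargs_split(list(bijector_kwargs = list(scale = 2)))
str(parts)
```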
A Bijector is expected to implement the following functions: forward, inverse,
and inverse_log_det_jacobian. The semantics of these functions are outlined in
the Bijector documentation.
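To make the three required functions concrete, the sketch below models a bijector for g(x) = exp(x) as a plain R list. This list is a hypothetical stand-in, not the tfprobability Bijector API, and it ignores batching and event dimensions.

```r
# Hypothetical stand-in for a Bijector: a plain R list holding the three
# functions named above, for g(x) = exp(x). Not the tfprobability API.
exp_bijector <- list(
  forward = function(x) exp(x),
  inverse = function(y) log(y),
  # log |d g^{-1}(y) / dy| = log(1 / y) = -log(y), for y > 0
  inverse_log_det_jacobian = function(y) -log(y)
)

y <- c(0.5, 1, 2)
exp_bijector$forward(exp_bijector$inverse(y))  # recovers y (up to rounding)
exp_bijector$inverse_log_det_jacobian(y)
```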
We now describe how a TransformedDistribution alters the inputs and outputs of
a Distribution associated with a random variable (rv) X.
Write cdf(Y=y) for an absolutely continuous cumulative distribution function of
random variable Y; write the probability density function
pdf(Y=y) := d^k / (dy_1 ... dy_k) cdf(Y=y)
for its derivative with respect to Y evaluated at y. Assume that Y = g(X) where
g is a deterministic diffeomorphism, i.e., a non-random, continuous,
differentiable, and invertible function whose inverse is also differentiable.
Write the inverse of g as X = g^{-1}(Y) and (J o g)(x) for the Jacobian of g
evaluated at x.
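A small worked case of this notation with k = 1, assuming (for illustration only) a standard normal X and g(x) = exp(x), so that Y is log-normal. The sketch checks numerically that the derivative of cdf(Y=y) equals pdf(Y=y), and that the Jacobian of g^{-1} at y is the reciprocal of the Jacobian of g at g^{-1}(y). It uses only base R, not tfprobability.

```r
# Assumed example (not the tfprobability API): X ~ Normal(0, 1), g(x) = exp(x).
g_inv <- function(y) log(y)

# cdf(Y = y) = cdf(X = g^{-1}(y)), which holds because exp is increasing.
cdf_y <- function(y) pnorm(g_inv(y))

# pdf(Y = y) as the derivative of the cdf (k = 1), via central differences.
pdf_y_numeric <- function(y, h = 1e-5) (cdf_y(y + h) - cdf_y(y - h)) / (2 * h)

y <- c(0.3, 1, 2.5)

# Jacobian of g^{-1} at y: d/dy log(y) = 1 / y.
jac_g_inv <- 1 / y
# Jacobian of g at x = g^{-1}(y): d/dx exp(x) = exp(x), i.e. y itself.
jac_g_at_x <- exp(g_inv(y))

# Change of variables: pdf_Y(y) = pdf_X(g^{-1}(y)) * |J_{g^{-1}}(y)|.
pdf_y_formula <- dnorm(g_inv(y)) * abs(jac_g_inv)

print(cbind(numeric = pdf_y_numeric(y), formula = pdf_y_formula, dlnorm = dlnorm(y)))
```

Both the finite-difference derivative and the change-of-variables formula should agree with R's built-in log-normal density dlnorm().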
A TransformedDistribution implements the following operations:
sample
Mathematically: Y = g(X)
Programmatically: bijector.forward(distribution.sample(...))
log_prob
Mathematically: (log o pdf)(Y=y) = (log o pdf o g^{-1})(y) + (log o abs o det o J o g^{-1})(y)
Programmatically: (distribution.log_prob(bijector.inverse(y)) + bijector.inverse_log_det_jacobian(y))
log_cdf
Mathematically: (log o cdf)(Y=y) = (log o cdf o g^{-1})(y), for scalar, monotonically increasing g
Programmatically: distribution.log_cdf(bijector.inverse(y))
and similarly for cdf, prob, log_survival_function, and survival_function.
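The sample, log_prob, and log_cdf rules above can be sketched in base R without TensorFlow. The `transformed()` helper, the list-based `std_normal` base distribution, and the `exp_bijector` list are hypothetical illustrations, not tfprobability functions. With a standard normal base and g = exp, the results should agree with R's built-in log-normal functions dlnorm() and plnorm().

```r
# Hypothetical helper: build sample / log_prob / log_cdf for Y = g(X) from a
# base distribution and a bijector, both given as plain R lists. Illustration only.
transformed <- function(base, bijector) {
  list(
    # sample: Y = g(X)
    sample = function(n) bijector$forward(base$sample(n)),
    # log_prob: log pdf_X(g^{-1}(y)) + inverse log-det-Jacobian at y
    log_prob = function(y) base$log_prob(bijector$inverse(y)) +
      bijector$inverse_log_det_jacobian(y),
    # log_cdf: log cdf_X(g^{-1}(y)); valid here because g is increasing
    log_cdf = function(y) base$log_cdf(bijector$inverse(y))
  )
}

# Standard normal base distribution built from base R functions.
std_normal <- list(
  sample = function(n) rnorm(n),
  log_prob = function(x) dnorm(x, log = TRUE),
  log_cdf = function(x) pnorm(x, log.p = TRUE)
)

# g(x) = exp(x) and the pieces of its inverse needed above.
exp_bijector <- list(
  forward = function(x) exp(x),
  inverse = function(y) log(y),
  inverse_log_det_jacobian = function(y) -log(y)
)

log_normal <- transformed(std_normal, exp_bijector)

y <- c(0.5, 1, 2)
log_normal$log_prob(y)  # should match dlnorm(y, log = TRUE)
log_normal$log_cdf(y)   # should match plnorm(y, log.p = TRUE)
set.seed(1)
draws <- log_normal$sample(5)  # exp() of normal draws, so all positive
```

For a decreasing g the log_cdf line would need the base survival function instead, which is why the sketch states its monotonicity assumption.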
For usage examples see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions:
tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(),
tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(),
tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(),
tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(),
tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(),
tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(),
tfd_gaussian_process_regression_model(), tfd_gaussian_process(),
tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(),
tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(),
tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(),
tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(),
tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(),
tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(),
tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(),
tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(),
tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(),
tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(),
tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(),
tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(),
tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(),
tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(),
tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(),
tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(),
tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(),
tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(),
tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(),
tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(),
tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(),
tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(),
tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(),
tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(),
tfd_wishart(), tfd_zipf()