This distribution is parameterized by probs, a (batch of) parameters taking values in (0, 1). Note that, unlike in the Bernoulli case, probs does not correspond to a probability, but the same name is used because of the similarity with the Bernoulli.
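One way to see that probs is not a probability is that the mean of a continuous Bernoulli is generally not equal to probs. The base R sketch below does not use tfprobability, and the helper names unnorm, mean_numeric and cb_mean are hypothetical. It computes the mean in two ways: by numerically integrating the unnormalized density probs^x * (1 - probs)^(1 - x), and by the closed form probs / (2 * probs - 1) + 1 / (2 * atanh(1 - 2 * probs)), which follows from integrating that density and holds for probs != 0.5.

```r
# Unnormalized continuous Bernoulli density: probs^x * (1 - probs)^(1 - x).
unnorm <- function(x, probs) probs^x * (1 - probs)^(1 - x)

# Mean by numerical integration: E[X] = int x f(x) dx / int f(x) dx over [0, 1].
mean_numeric <- function(probs) {
  num <- integrate(function(x) x * unnorm(x, probs), lower = 0, upper = 1)$value
  den <- integrate(unnorm, lower = 0, upper = 1, probs = probs)$value
  num / den
}

# Closed-form mean, valid for probs != 0.5 (at probs = 0.5 the density is
# uniform on [0, 1] and the mean is 0.5).
cb_mean <- function(probs) {
  probs / (2 * probs - 1) + 1 / (2 * atanh(1 - 2 * probs))
}

p <- 0.3
print(c(numeric = mean_numeric(p), closed_form = cb_mean(p)))  # both about 0.43, not 0.3
```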
tfd_continuous_bernoulli(
logits = NULL,
probs = NULL,
dtype = tf$float32,
validate_args = FALSE,
allow_nan_stats = TRUE,
name = "ContinuousBernoulli"
)
Returns a distribution instance.
logits: An N-D Tensor. Each entry in the Tensor parameterizes an independent continuous Bernoulli distribution with parameter sigmoid(logits). Only one of logits or probs should be passed in. Note that this does not correspond to the log-odds as in the Bernoulli case.
probs: An N-D Tensor representing the parameter of a continuous Bernoulli. Each entry in the Tensor parameterizes an independent continuous Bernoulli distribution. Only one of logits or probs should be passed in. Note that this also does not correspond to a probability as in the Bernoulli case.
dtype: The type of the event samples. Default: tf$float32.
validate_args: Logical, default FALSE. When TRUE, distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE, invalid inputs may silently render incorrect outputs.
allow_nan_stats: Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined.
name: Name prefixed to Ops created by this class.
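As stated for logits, each entry is mapped to the distribution's parameter by the sigmoid function. In base R the sigmoid is plogis() and its inverse is qlogis(), so the sketch below (plain R, no TensorFlow) shows how a logits value corresponds to a probs value. It assumes, as the argument descriptions say, that logits = l describes the same distribution as probs = plogis(l).

```r
# Sigmoid maps logits in (-Inf, Inf) to the parameter range (0, 1).
logits <- c(-2, 0, 1.5)
probs <- plogis(logits)   # same as 1 / (1 + exp(-logits))
print(probs)

# qlogis() inverts the map and recovers the original logits.
print(all.equal(qlogis(probs), logits))

# logits = 0 gives probs = 0.5, where the continuous Bernoulli density is
# constant, i.e. the uniform distribution on [0, 1].
print(plogis(0))
```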
Mathematical Details
The continuous Bernoulli is a distribution over the interval [0, 1], parameterized by probs in (0, 1).
The probability density function (pdf) is:
pdf(x; probs) = probs^x * (1 - probs)^(1 - x) * C(probs)
where C(probs) = 2 * atanh(1 - 2 * probs) / (1 - 2 * probs) for probs != 0.5, and C(0.5) = 2.
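The density and normalizing constant above can be checked numerically. This base R sketch is independent of tfprobability, and the helper names cb_norm_const and cb_density are hypothetical. It implements pdf(x; probs) with the exact constant C(probs) and uses integrate() to confirm that the density integrates to 1 over [0, 1] for several values of probs.

```r
# Exact normalizing constant C(probs); probs = 0.5 is handled separately
# because the general formula evaluates to 0/0 there.
cb_norm_const <- function(probs) {
  if (probs == 0.5) 2 else 2 * atanh(1 - 2 * probs) / (1 - 2 * probs)
}

# pdf(x; probs) = probs^x * (1 - probs)^(1 - x) * C(probs), for x in [0, 1].
cb_density <- function(x, probs) {
  probs^x * (1 - probs)^(1 - x) * cb_norm_const(probs)
}

# Each integral should equal 1 up to numerical integration error.
for (p in c(0.1, 0.3, 0.5, 0.9)) {
  total <- integrate(cb_density, lower = 0, upper = 1, probs = p)$value
  cat(sprintf("probs = %.1f: integral = %.6f\n", p, total))
}
```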
While the normalizing constant C(probs) is a continuous function of probs (even at probs = 0.5), computing it at values close to 0.5 can result in numerical instabilities due to 0/0 errors. A Taylor approximation of C(probs) is therefore used for values of probs in a small interval [lims[0], lims[1]] around 0.5 (lims is not an argument of tfd_continuous_bernoulli()). For more details, see Loaiza-Ganem and Cunningham (2019).
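A minimal base R sketch of this idea follows. It is not the TensorFlow Probability implementation. It uses the series 2 * atanh(t) / t = 2 * (1 + t^2/3 + t^4/5 + ...) with t = 1 - 2 * probs, truncated after the t^4 term, and the switch-over interval (0.49, 0.51) is an illustrative assumption, not the library's actual lims.

```r
# Exact formula; evaluates to NaN (0/0) at probs = 0.5.
cb_norm_const_exact <- function(probs) {
  2 * atanh(1 - 2 * probs) / (1 - 2 * probs)
}

# Truncated series: 2 * atanh(t) / t ~= 2 * (1 + t^2 / 3 + t^4 / 5).
cb_norm_const_taylor <- function(probs) {
  t <- 1 - 2 * probs
  2 * (1 + t^2 / 3 + t^4 / 5)
}

# Hypothetical cutoffs. R indexing is 1-based, so lims[1] and lims[2] here
# play the role of lims[0] and lims[1] in the text.
lims <- c(0.49, 0.51)

# Use the series inside the interval and the exact formula elsewhere.
cb_norm_const <- function(probs) {
  ifelse(probs > lims[1] & probs < lims[2],
         cb_norm_const_taylor(probs),
         cb_norm_const_exact(probs))
}

print(cb_norm_const_exact(0.5))   # NaN
print(cb_norm_const(0.5))         # 2
print(cb_norm_const(c(0.2, 0.495, 0.5, 0.8)))
```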
NOTE: Unlike with the Bernoulli, numerical instabilities can occur for probs very close to 0 or 1. The current implementation allows any value in (0, 1), but this could be restricted to (1e-6, 1 - 1e-6) to avoid these issues.
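The restriction suggested in the note can be sketched in base R as a clamp applied before the normalizing constant is computed. The helper clamp_probs() is hypothetical and the bounds 1e-6 and 1 - 1e-6 are taken from the note; this illustrates the suggestion and is not what tfd_continuous_bernoulli() currently does. Without the clamp, a value such as probs = 1e-20 makes 1 - 2 * probs round to exactly 1 in double precision, so atanh() returns Inf.

```r
# Exact normalizing constant (valid for probs != 0.5).
cb_norm_const <- function(probs) {
  2 * atanh(1 - 2 * probs) / (1 - 2 * probs)
}

# Hypothetical helper: restrict probs to (eps, 1 - eps) as the note suggests.
clamp_probs <- function(probs, eps = 1e-6) {
  pmin(pmax(probs, eps), 1 - eps)
}

tiny <- 1e-20
print(1 - 2 * tiny == 1)                  # TRUE: the subtraction rounds to 1
print(cb_norm_const(tiny))                # Inf
print(cb_norm_const(clamp_probs(tiny)))   # finite, about 13.8
```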
Loaiza-Ganem G and Cunningham JP. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. NeurIPS 2019. https://arxiv.org/abs/1907.06845
For usage examples, see e.g. tfd_sample(), tfd_log_prob(), and tfd_mean().
Other distributions:
tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(), tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(), tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(), tfd_deterministic(), tfd_dirichlet_multinomial(), tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(), tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(), tfd_gaussian_process_regression_model(), tfd_gaussian_process(), tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(), tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(), tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(), tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(), tfd_joint_distribution_named(), tfd_joint_distribution_sequential_auto_batched(), tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(), tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(), tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(), tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(), tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(), tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(), tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(), tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(), tfd_pixel_cnn(), tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(), tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(), tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(), tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(), tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(), tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(), tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(), tfd_vector_diffeomixture(), tfd_vector_exponential_diag(), tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(), tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(), tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(), tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()