The MixtureSameFamily
distribution implements a (batch of) mixture
distribution where all components are from different parameterizations of the
same distribution type. It is parameterized by a Categorical
"selecting
distribution" (over k
components) and a components distribution, i.e., a
Distribution
with a rightmost batch shape (equal to [k]
) which indexes
each (batch of) component.
tfd_mixture_same_family(
mixture_distribution,
components_distribution,
reparameterize = FALSE,
validate_args = FALSE,
allow_nan_stats = TRUE,
name = "MixtureSameFamily"
)
a distribution instance.
tfp$distributions$Categorical
-like instance.
Manages the probability of selecting components. The number of
categories must match the rightmost batch dimension of the
components_distribution
. Must have either scalar batch_shape
or
batch_shape
matching components_distribution$batch_shape[:-1]
.
tfp$distributions$Distribution
-like instance.
Right-most batch dimension indexes components.
Logical, default FALSE
. Whether to reparameterize
samples of the distribution using implicit reparameterization gradients
(Figurnov et al., 2018). The gradients for the mixture logits are
equivalent to the ones described by (Graves, 2016). The gradients
for the components parameters are also computed using implicit
reparameterization (as opposed to ancestral sampling), meaning that
all components are updated every step.
Only works when:
(1) components_distribution is fully reparameterized;
(2) components_distribution is either a scalar distribution or
fully factorized (tfd.Independent applied to a scalar distribution);
(3) batch shape has a known rank.
Experimental, may be slow and produce infs/NaNs.
Logical, default FALSE. When TRUE distribution parameters are checked for validity despite possibly degrading runtime performance. When FALSE invalid inputs may silently render incorrect outputs. Default value: FALSE.
Logical, default TRUE. When TRUE, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When FALSE, an exception is raised if one or more of the statistic's batch members are undefined.
name prefixed to Ops created by this class.
For usage examples see e.g. tfd_sample()
, tfd_log_prob()
, tfd_mean()
.
Other distributions:
tfd_autoregressive()
,
tfd_batch_reshape()
,
tfd_bates()
,
tfd_bernoulli()
,
tfd_beta_binomial()
,
tfd_beta()
,
tfd_binomial()
,
tfd_categorical()
,
tfd_cauchy()
,
tfd_chi2()
,
tfd_chi()
,
tfd_cholesky_lkj()
,
tfd_continuous_bernoulli()
,
tfd_deterministic()
,
tfd_dirichlet_multinomial()
,
tfd_dirichlet()
,
tfd_empirical()
,
tfd_exp_gamma()
,
tfd_exp_inverse_gamma()
,
tfd_exponential()
,
tfd_gamma_gamma()
,
tfd_gamma()
,
tfd_gaussian_process_regression_model()
,
tfd_gaussian_process()
,
tfd_generalized_normal()
,
tfd_geometric()
,
tfd_gumbel()
,
tfd_half_cauchy()
,
tfd_half_normal()
,
tfd_hidden_markov_model()
,
tfd_horseshoe()
,
tfd_independent()
,
tfd_inverse_gamma()
,
tfd_inverse_gaussian()
,
tfd_johnson_s_u()
,
tfd_joint_distribution_named_auto_batched()
,
tfd_joint_distribution_named()
,
tfd_joint_distribution_sequential_auto_batched()
,
tfd_joint_distribution_sequential()
,
tfd_kumaraswamy()
,
tfd_laplace()
,
tfd_linear_gaussian_state_space_model()
,
tfd_lkj()
,
tfd_log_logistic()
,
tfd_log_normal()
,
tfd_logistic()
,
tfd_mixture()
,
tfd_multinomial()
,
tfd_multivariate_normal_diag_plus_low_rank()
,
tfd_multivariate_normal_diag()
,
tfd_multivariate_normal_full_covariance()
,
tfd_multivariate_normal_linear_operator()
,
tfd_multivariate_normal_tri_l()
,
tfd_multivariate_student_t_linear_operator()
,
tfd_negative_binomial()
,
tfd_normal()
,
tfd_one_hot_categorical()
,
tfd_pareto()
,
tfd_pixel_cnn()
,
tfd_poisson_log_normal_quadrature_compound()
,
tfd_poisson()
,
tfd_power_spherical()
,
tfd_probit_bernoulli()
,
tfd_quantized()
,
tfd_relaxed_bernoulli()
,
tfd_relaxed_one_hot_categorical()
,
tfd_sample_distribution()
,
tfd_sinh_arcsinh()
,
tfd_skellam()
,
tfd_spherical_uniform()
,
tfd_student_t_process()
,
tfd_student_t()
,
tfd_transformed_distribution()
,
tfd_triangular()
,
tfd_truncated_cauchy()
,
tfd_truncated_normal()
,
tfd_uniform()
,
tfd_variational_gaussian_process()
,
tfd_vector_diffeomixture()
,
tfd_vector_exponential_diag()
,
tfd_vector_exponential_linear_operator()
,
tfd_vector_laplace_diag()
,
tfd_vector_laplace_linear_operator()
,
tfd_vector_sinh_arcsinh_diag()
,
tfd_von_mises_fisher()
,
tfd_von_mises()
,
tfd_weibull()
,
tfd_wishart_linear_operator()
,
tfd_wishart_tri_l()
,
tfd_wishart()
,
tfd_zipf()