Pixel CNN++ (Salimans et al., 2017) models a distribution over image
data, parameterized by a neural network. It builds on Pixel CNN and
Conditional Pixel CNN, as originally proposed by van den Oord et al. (2016).
The model expresses the joint distribution over pixels as
the product of conditional distributions:

$$p(x \mid h) = \prod_{i=0}^{d} p(x_i \mid x_{0:i}, h),$$

in which $p(x_i \mid x_{0:i}, h)$ is the probability of the $i$-th pixel
conditional on the pixels that preceded it in raster order (color channels in
RGB order, then left to right, then top to bottom). Here $h$ is optional
additional data on which to condition the image distribution, such as class
labels or VAE embeddings. The Pixel CNN++
network enforces the dependency structure among pixels by applying a mask to
the kernels of the convolutional layers that ensures that the values for each
pixel depend only on other pixels up and to the left.
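To make the raster order concrete, the sketch below maps a (row, column, channel) triple to its position in that ordering, with the channel varying fastest, then the column, then the row. It uses 1-based R indices rather than the 0-based indices of the formula above. The helper `raster_position()` is hypothetical, written in base R for illustration only, and is not part of tfprobability. A value at position $i$ is conditioned on every value at an earlier position.

```r
# Hypothetical helper (not part of tfprobability): 1-based position of the
# value at (row, col, channel) in raster order. Channels vary fastest
# (R, G, B), then columns (left to right), then rows (top to bottom).
raster_position <- function(row, col, channel, width, channels = 3) {
  ((row - 1) * width + (col - 1)) * channels + channel
}

# Enumerate every value of a 2 x 4 RGB image. expand.grid() varies its first
# argument fastest, so its rows are already listed in raster order.
height <- 2
width <- 4
grid <- expand.grid(channel = 1:3, col = seq_len(width), row = seq_len(height))
pos <- raster_position(grid$row, grid$col, grid$channel, width = width)

# Number of earlier values each pixel value is conditioned on
n_conditioned <- pos - 1
```

Because `expand.grid()` enumerates combinations with its first argument varying fastest, `pos` runs from 1 to 24 in order, which is a quick check that the formula matches the stated ordering.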
Pixel values are modeled with a mixture of quantized logistic distributions,
which can take on a set of distinct integer values (e.g. between 0 and 255
for an 8-bit image).
The color intensity $v$ of each pixel is modeled as

$$v \sim \sum_{i=0}^{k-1} q_i \, \mathrm{QuantizedLogistic}(\mathrm{loc}_i, \mathrm{scale}_i),$$

in which $k$ is the number of mixture components and the $q_i$ are the
Categorical probabilities over the components.
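A minimal base-R sketch of this mixture, assuming the usual quantization scheme in which an integer value $v$ receives the logistic probability mass on $[v - 0.5, v + 0.5]$, while the edge values `low` and `high` absorb the remaining tail mass so that the probabilities over `low:high` sum to 1. The function `quantized_logistic_mixture_pmf()` is hypothetical; tfprobability's internal computation (for example, any rescaling of pixel values) may differ.

```r
# Hypothetical sketch (not the tfprobability implementation): probability mass
# of integer intensities v under a mixture of quantized logistic distributions.
quantized_logistic_mixture_pmf <- function(v, q, loc, scale, low = 0, high = 255) {
  stopifnot(length(q) == length(loc), length(loc) == length(scale),
            abs(sum(q) - 1) < 1e-8, all(scale > 0))
  pmf <- numeric(length(v))
  for (i in seq_along(q)) {
    # Logistic CDF at the upper and lower edge of each integer's bin; the
    # edge values low and high take all of the tail mass beyond them.
    upper <- ifelse(v >= high, 1, plogis(v + 0.5, location = loc[i], scale = scale[i]))
    lower <- ifelse(v <= low, 0, plogis(v - 0.5, location = loc[i], scale = scale[i]))
    pmf <- pmf + q[i] * (upper - lower)
  }
  pmf
}

# Two-component mixture over 8-bit intensities
v <- 0:255
p <- quantized_logistic_mixture_pmf(v, q = c(0.7, 0.3),
                                    loc = c(40, 200), scale = c(10, 25))
sum(p)  # equals 1 up to floating-point rounding
```

The per-bin differences telescope, so the total mass is exactly 1 in exact arithmetic. Note that with these parameters a sizable share of the second component's mass piles up at `high`, which is the intended effect of the edge handling.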
tfd_pixel_cnn(
image_shape,
conditional_shape = NULL,
num_resnet = 5,
num_hierarchies = 3,
num_filters = 160,
num_logistic_mix = 10,
receptive_field_dims = c(3, 3),
dropout_p = 0.5,
resnet_activation = "concat_elu",
use_weight_norm = TRUE,
use_data_init = TRUE,
high = 255,
low = 0,
dtype = tf$float32,
name = "PixelCNN"
)
image_shape: 3D TensorShape or tuple for the [height, width, channels]
dimensions of the image.
conditional_shape: TensorShape or tuple for the shape of the
conditional input, or NULL if there is no conditional input.
num_resnet: integer, the number of layers (shown in Figure 2 of
https://arxiv.org/abs/1606.05328) within each highest-level block of Figure 2
of https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf.
num_hierarchies: integer, the number of highest-level blocks (separated by
expansions/contractions of dimensions in Figure 2 of
https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf).
num_filters: integer, the number of convolutional filters.
num_logistic_mix: integer, the number of components in the logistic mixture
distribution.
receptive_field_dims: tuple, height and width in pixels of the receptive
field of the convolutional layers above and to the left of a given
pixel. The width (second element of the tuple) should be odd. Figure 1
(middle) of https://arxiv.org/abs/1606.05328 shows a receptive field of (3, 5)
(the row containing the current pixel is included in the height).
The default of (3, 3) was used to produce the results in
https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf.
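As an illustration of the receptive-field shape described for `receptive_field_dims`, the sketch below builds a logical mask for a kernel of size `c(height, width)`, placing the current pixel in the bottom row and middle column: every row above it is visible, and in its own row only the pixels to its left are. This follows the general masking idea of the original PixelCNN; `causal_mask()` and its `include_center` flag are hypothetical and do not describe how tfprobability actually builds its layers.

```r
# Hypothetical helper (not part of tfprobability): TRUE marks kernel positions
# the current pixel may depend on. The current pixel sits in the last row,
# middle column, so only pixels above it and to its left are visible.
causal_mask <- function(height, width, include_center = FALSE) {
  stopifnot(height >= 1, width %% 2 == 1)
  center_col <- (width + 1) / 2
  mask <- matrix(FALSE, nrow = height, ncol = width)
  if (height > 1) {
    mask[seq_len(height - 1), ] <- TRUE            # all rows above the pixel
  }
  mask[height, seq_len(center_col - 1)] <- TRUE    # left of the pixel, same row
  if (include_center) {
    mask[height, center_col] <- TRUE               # optionally the pixel itself
  }
  mask
}

causal_mask(3, 5)
```

The odd-width requirement in the sketch mirrors the documented constraint: it guarantees a single middle column for the current pixel.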
dropout_p: float, the dropout probability. Should be between 0 and 1.
resnet_activation: string, the type of activation to use in the resnet blocks.
May be 'concat_elu', 'elu', or 'relu'.
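The `'concat_elu'` activation used in PixelCNN++ applies an ELU to both the input and its negation and concatenates the two results along the channel axis, which doubles the number of channels. A minimal base-R sketch on the channel vector of a single pixel; the helpers `elu()` and `concat_elu()` are hypothetical illustrations, not tfprobability functions.

```r
# Hypothetical helpers: ELU with alpha = 1, and concat_elu applied to the
# channel values of one pixel. The output has twice as many channels.
elu <- function(x) ifelse(x > 0, x, exp(x) - 1)
concat_elu <- function(x) c(elu(x), elu(-x))

concat_elu(c(1, -2, 0))
```

Unlike a plain ELU, this keeps information from both the positive and negative parts of each input channel, at the cost of doubling the channel dimension that the next layer must accept.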
use_weight_norm: logical, if TRUE, use weight normalization (works
only in Eager mode).
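Weight normalization reparameterizes a weight vector $w$ as $w = g \, v / \lVert v \rVert$, so that its length $g$ and its direction $v$ are learned separately. The sketch below shows that reparameterization for a single vector using a hypothetical `weight_norm()` helper in base R; it illustrates the idea only and is not how tfprobability applies it to convolutional kernels.

```r
# Hypothetical helper: weight-normalized vector w = g * v / ||v||.
# The direction comes from v; the Euclidean length of w is |g|.
weight_norm <- function(v, g) {
  g * v / sqrt(sum(v^2))
}

w <- weight_norm(c(3, 4), g = 2)
sqrt(sum(w^2))  # length of w equals g
```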
use_data_init: logical, if TRUE, use data-dependent initialization
(has no effect if use_weight_norm is FALSE).
high: integer, the maximum value of the input data (255 for an 8-bit image).
low: integer, the minimum value of the input data.
dtype: data type of the Distribution.
name: string, the name of the Distribution.
Returns a distribution instance.
For usage examples, see e.g. tfd_sample(), tfd_log_prob(), tfd_mean().
Other distributions:
tfd_autoregressive(), tfd_batch_reshape(), tfd_bates(), tfd_bernoulli(),
tfd_beta_binomial(), tfd_beta(), tfd_binomial(), tfd_categorical(),
tfd_cauchy(), tfd_chi2(), tfd_chi(), tfd_cholesky_lkj(),
tfd_continuous_bernoulli(), tfd_deterministic(), tfd_dirichlet_multinomial(),
tfd_dirichlet(), tfd_empirical(), tfd_exp_gamma(), tfd_exp_inverse_gamma(),
tfd_exponential(), tfd_gamma_gamma(), tfd_gamma(),
tfd_gaussian_process_regression_model(), tfd_gaussian_process(),
tfd_generalized_normal(), tfd_geometric(), tfd_gumbel(), tfd_half_cauchy(),
tfd_half_normal(), tfd_hidden_markov_model(), tfd_horseshoe(),
tfd_independent(), tfd_inverse_gamma(), tfd_inverse_gaussian(),
tfd_johnson_s_u(), tfd_joint_distribution_named_auto_batched(),
tfd_joint_distribution_named(),
tfd_joint_distribution_sequential_auto_batched(),
tfd_joint_distribution_sequential(), tfd_kumaraswamy(), tfd_laplace(),
tfd_linear_gaussian_state_space_model(), tfd_lkj(), tfd_log_logistic(),
tfd_log_normal(), tfd_logistic(), tfd_mixture_same_family(), tfd_mixture(),
tfd_multinomial(), tfd_multivariate_normal_diag_plus_low_rank(),
tfd_multivariate_normal_diag(), tfd_multivariate_normal_full_covariance(),
tfd_multivariate_normal_linear_operator(), tfd_multivariate_normal_tri_l(),
tfd_multivariate_student_t_linear_operator(), tfd_negative_binomial(),
tfd_normal(), tfd_one_hot_categorical(), tfd_pareto(),
tfd_poisson_log_normal_quadrature_compound(), tfd_poisson(),
tfd_power_spherical(), tfd_probit_bernoulli(), tfd_quantized(),
tfd_relaxed_bernoulli(), tfd_relaxed_one_hot_categorical(),
tfd_sample_distribution(), tfd_sinh_arcsinh(), tfd_skellam(),
tfd_spherical_uniform(), tfd_student_t_process(), tfd_student_t(),
tfd_transformed_distribution(), tfd_triangular(), tfd_truncated_cauchy(),
tfd_truncated_normal(), tfd_uniform(), tfd_variational_gaussian_process(),
tfd_vector_diffeomixture(), tfd_vector_exponential_diag(),
tfd_vector_exponential_linear_operator(), tfd_vector_laplace_diag(),
tfd_vector_laplace_linear_operator(), tfd_vector_sinh_arcsinh_diag(),
tfd_von_mises_fisher(), tfd_von_mises(), tfd_weibull(),
tfd_wishart_linear_operator(), tfd_wishart_tri_l(), tfd_wishart(), tfd_zipf()