Type QuantizedDistribution
Namespace tensorflow.contrib.distributions
Parent Distribution
Interfaces IQuantizedDistribution
Distribution representing the quantization `Y = ceiling(X)`.

#### Definition in Terms of Sampling

```
1. Draw X
2. Set Y <-- ceiling(X)
3. If Y < low, reset Y <-- low
4. If Y > high, reset Y <-- high
5. Return Y
```

#### Definition in Terms of the Probability Mass Function

Given scalar random variable `X`, we define a discrete random variable `Y`
supported on the integers as follows:

```
P[Y = j] := P[X <= low],         if j == low,
         := P[X > high - 1],     if j == high,
         := 0,                   if j < low or j > high,
         := P[j - 1 < X <= j],   all other j.
```

Conceptually, without cutoffs, the quantization process partitions the real
line `R` into half open intervals, and identifies an integer `j` with the
right endpoints:

```
R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ...
j = ...      -1      0     1     2     3     4  ...
```

`P[Y = j]` is the mass of `X` within the `j`th interval.
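The two definitions above can be cross-checked numerically. The following is a minimal sketch in plain Python, not the library's implementation: the helper names `quantized_pmf`, `quantized_sample`, and `normal_cdf` are hypothetical, a standard normal stands in for the base variable `X`, and `low < high` is assumed whenever cutoffs are given.

```python
import math
import random


def normal_cdf(x):
    """CDF of a standard normal, standing in for the base variable X."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def quantized_pmf(j, cdf, low=None, high=None):
    """P[Y = j] for Y = ceiling(X), optionally clipped to [low, high]."""
    if low is not None and j < low:
        return 0.0
    if high is not None and j > high:
        return 0.0
    if low is not None and j == low:
        return cdf(low)                # P[X <= low]
    if high is not None and j == high:
        return 1.0 - cdf(high - 1)     # P[X > high - 1]
    return cdf(j) - cdf(j - 1)         # P[j - 1 < X <= j]


def quantized_sample(draw_x, low=None, high=None):
    """Sampling definition: draw X, take the ceiling, then clip."""
    y = math.ceil(draw_x())
    if low is not None and y < low:
        y = low
    if high is not None and y > high:
        y = high
    return y


# Mass function on the clipped support {0, 1, 2}.
pmf = [quantized_pmf(j, normal_cdf, low=0, high=2) for j in range(3)]

# One draw following the sampling definition with the same cutoffs.
rng = random.Random(0)
y = quantized_sample(lambda: rng.gauss(0.0, 1.0), low=0, high=2)
```

With cutoffs, the first and last entries of `pmf` absorb the two tails of `X`, so the entries sum to one; drawing many values with `quantized_sample` should give empirical frequencies close to `pmf`.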
If `low = 0`, and `high = 2`, then the intervals are redrawn
and `j` is re-assigned:

```
R = (-infty, 0](0, 1](1, infty)
j =          0     1     2
```

`P[Y = j]` is still the mass of `X` within the `j`th interval.

#### Examples

We illustrate a mixture of discretized logistic distributions
[(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in
a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures
`P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
The lowest value has probability `P(X <= 0.5)` and the highest value has
probability `P(2**16 - 1.5 < X)`.

Below we assume a `wavenet` function. It takes as `inputs` right-shifted audio
samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
`scale` parameter belonging to the logistic distribution, and a `logits`
parameter determining the unnormalized probability of that component.
After instantiating `mixture_dist`, we illustrate maximum likelihood by
calculating its log-probability of the audio samples `targets` and optimizing.

#### References

[1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
PixelCNN++: Improving the PixelCNN with discretized logistic mixture
likelihood and other modifications.
_International Conference on Learning Representations_, 2017.
https://arxiv.org/abs/1701.05517
[2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
https://arxiv.org/abs/1711.10433
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# `wavenet`, `inputs`, and `targets` are assumed to be defined elsewhere.
net = wavenet(inputs)
loc, unconstrained_scale, logits = tf.split(net,
                                            num_or_size_splits=3,
                                            axis=-1)
scale = tf.nn.softplus(unconstrained_scale)

# Form mixture of discretized logistic distributions. Note we shift the
# logistic distribution by -0.5. This lets the quantization capture "rounding"
# intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
discretized_logistic_dist = tfd.QuantizedDistribution(
    distribution=tfd.TransformedDistribution(
        distribution=tfd.Logistic(loc=loc, scale=scale),
        bijector=tfb.AffineScalar(shift=-0.5)),
    low=0.,
    high=2**16 - 1.)
mixture_dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=logits),
    components_distribution=discretized_logistic_dist)

neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
```
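The effect of the `-0.5` shift can be checked without TensorFlow. The sketch below is an illustration under stated assumptions rather than the library's code: `logistic_cdf` and `discretized_logistic_pmf` are hypothetical helpers, and the values of `loc` and `scale` are arbitrary. It applies the ceiling quantization to `X - 0.5`, which turns the "ceiling" interval `(x-1, x]` into the "rounding" interval `(x-0.5, x+0.5]` for `X`.

```python
import math


def logistic_cdf(x, loc, scale):
    """Numerically stable CDF of Logistic(loc, scale)."""
    z = (x - loc) / scale
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def discretized_logistic_pmf(x, loc, scale, low=0, high=2**16 - 1):
    """P[Y = x], where Y quantizes X - 0.5 and X ~ Logistic(loc, scale)."""

    def shifted_cdf(t):
        # The CDF of X - 0.5 at t equals the CDF of X at t + 0.5.
        return logistic_cdf(t + 0.5, loc, scale)

    if x < low or x > high:
        return 0.0
    if x == low:
        return shifted_cdf(low)                  # P[X <= low + 0.5]
    if x == high:
        return 1.0 - shifted_cdf(high - 1)       # P[X > high - 0.5]
    return shifted_cdf(x) - shifted_cdf(x - 1)   # P[x - 0.5 < X <= x + 0.5]


# Arbitrary illustrative parameters for a single mixture component.
loc, scale = 1000.3, 50.0
interior = discretized_logistic_pmf(1000, loc, scale)
lowest = discretized_logistic_pmf(0, loc, scale)
highest = discretized_logistic_pmf(2**16 - 1, loc, scale)
```

Here `lowest` equals `P(X <= 0.5)` and `highest` equals `P(2**16 - 1.5 < X)`, matching the endpoint probabilities stated in the description.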
Properties
- allow_nan_stats
- allow_nan_stats_dyn
- batch_shape
- batch_shape_dyn
- distribution
- distribution_dyn
- dtype
- dtype_dyn
- event_shape
- event_shape_dyn
- high
- high_dyn
- low
- low_dyn
- name
- name_dyn
- parameters
- parameters_dyn
- PythonObject
- reparameterization_type
- reparameterization_type_dyn
- validate_args
- validate_args_dyn
Public properties
object allow_nan_stats get;
object allow_nan_stats_dyn get;
TensorShape batch_shape get;
object batch_shape_dyn get;
object distribution get;
Base distribution, p(x).
object distribution_dyn get;
Base distribution, p(x).
object dtype get;
object dtype_dyn get;
TensorShape event_shape get;
object event_shape_dyn get;
object high get;
Highest value that quantization returns.
object high_dyn get;
Highest value that quantization returns.
object low get;
Lowest value that quantization returns.
object low_dyn get;
Lowest value that quantization returns.