Type BatchNormalization
Namespace tensorflow.contrib.distributions.bijectors
Parent Bijector
Interfaces IBatchNormalization
Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`.

Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a
data distribution. This can be used to stabilize training of normalizing
flows ([Papamakarios et al., 2017][3]; [Dinh et al., 2017][2]).

When training Deep Neural Networks (DNNs), it is common practice to
normalize or whiten features by shifting them to have zero mean and
scaling them to have unit variance.

The `inverse()` method of the `BatchNormalization` bijector, which is used in
the log-likelihood computation of data samples, implements the normalization
procedure (shift-and-scale) using the mean and standard deviation of the
current minibatch.

Conversely, the `forward()` method of the bijector de-normalizes samples
(e.g. `X * std(Y) + mean(Y)`), using the running-average mean and standard
deviation computed at training time. De-normalization is useful for sampling.
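Both directions amount to a simple shift-and-scale. The following is a minimal sketch of that arithmetic in plain TensorFlow 1.x ops, for illustration only; the actual bijector also applies the learned batch-norm scale and offset parameters and maintains the running averages for you:

```python
import tensorflow as tf

y = tf.random_normal([100, 2], mean=1., stddev=2.)  # stand-in data batch

# What inverse() computes: normalize with the current minibatch statistics.
mean, variance = tf.nn.moments(y, axes=[0])
x = (y - mean) / tf.sqrt(variance)  # approximately zero mean, unit variance

# What forward() computes: de-normalize. Shown here with the minibatch
# statistics; the bijector instead uses running averages from training.
y_denormalized = x * tf.sqrt(variance) + mean
```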
During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not
guaranteed to be inverses of each other because `inverse(y)` uses statistics
of the current minibatch, while `forward(x)` uses running-average statistics
accumulated from training. In other words,
`BatchNorm.inverse(BatchNorm.forward(...))` and
`BatchNorm.forward(BatchNorm.inverse(...))` will be identical when
`training=False` but may be different when `training=True`.

#### References

[1]: Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating
     Deep Network Training by Reducing Internal Covariate Shift. In
     _International Conference on Machine Learning_, 2015.
     https://arxiv.org/abs/1502.03167

[2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation
     using Real NVP. In _International Conference on Learning
     Representations_, 2017. https://arxiv.org/abs/1605.08803

[3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked
     Autoregressive Flow for Density Estimation. In _Neural Information
     Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
#### Example

```python
import tensorflow as tf

tfd = tf.contrib.distributions
tfb = tfd.bijectors

dist = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.BatchNormalization())

y = tfd.MultivariateNormalDiag(loc=[1.], scale_diag=[2.]).sample(100)  # ~ N(1, 2)
x = dist.bijector.inverse(y)  # ~ N(0, 1)
y = dist.sample()             # ~ N(1, 2)
```
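As noted above, the forward/inverse round trip is only guaranteed to be the identity at inference time. In this API the mode is selected by the bijector's `training` constructor flag; the sketch below illustrates the distinction (TF1-style session setup and variable initialization are omitted for brevity):

```python
# training=True (the default): inverse() normalizes with minibatch
# statistics while forward() de-normalizes with running averages, so
# forward(inverse(y)) is not guaranteed to equal y.
batch_norm_train = tfb.BatchNormalization(training=True)

# training=False: both directions use the accumulated running averages,
# so forward() and inverse() are exact inverses of each other.
batch_norm_eval = tfb.BatchNormalization(training=False)

y = tf.random_normal([100, 2], mean=1., stddev=2.)
y_roundtrip = batch_norm_eval.forward(batch_norm_eval.inverse(y))
# y_roundtrip matches y up to numerical error once the batch-norm
# variables have been initialized.
```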