Type BatchReshape
Namespace tensorflow.contrib.distributions
Parent Distribution
Interfaces IBatchReshape
The Batch-Reshaping distribution.

This "meta-distribution" reshapes the batch dimensions of another distribution. The event shape of the wrapped distribution is left unchanged.

#### Examples
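The new `batch_shape` may contain at most one -1, which is inferred from the total number of batch members, much as `np.reshape` does. The sketch below illustrates only that shape arithmetic. `resolve_batch_shape` is a hypothetical helper written for this page. It is not part of TensorFlow or TFP, and it is not how either library implements the check.

```python
import math
from typing import Sequence, Tuple


def resolve_batch_shape(old_batch_shape: Sequence[int],
                        new_batch_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: infer a single -1 entry and check sizes match.

    Only the batch dimensions are considered; event dimensions are untouched.
    """
    if any(d < -1 for d in new_batch_shape):
        raise ValueError("dimensions must be non-negative or -1")
    unknown = [i for i, d in enumerate(new_batch_shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension may be -1")

    total = math.prod(old_batch_shape)  # number of batch members
    known = math.prod(d for d in new_batch_shape if d != -1)
    resolved = list(new_batch_shape)
    if unknown:
        # The -1 entry absorbs whatever is left of the batch size.
        if known == 0 or total % known != 0:
            raise ValueError("batch shapes are incompatible")
        resolved[unknown[0]] = total // known
    if math.prod(resolved) != total:
        raise ValueError("batch shapes have different numbers of elements")
    return tuple(resolved)


print(resolve_batch_shape([6], [1, 2, -1]))  # (1, 2, 3)
```

With an old batch shape of [6], the request [1, 2, -1] resolves to (1, 2, 3). Samples then have shape sample_shape + (1, 2, 3) + (dims,), because the event dimension is appended unchanged.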
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

dtype = np.float32
dims = 2                      # event size of each multivariate normal
new_batch_shape = [1, 2, -1]  # -1 is inferred from the batch size
old_batch_shape = [6]

# A batch of 6 two-dimensional diagonal normals: batch_shape [6], event_shape [2].
scale = np.ones(old_batch_shape + [dims], dtype)
mvn = tfd.MultivariateNormalDiag(scale_diag=scale)

reshape_mvn = tfd.BatchReshape(
    distribution=mvn,
    batch_shape=new_batch_shape,
    validate_args=True)
reshape_mvn.batch_shape
# ==> [1, 2, 3]

x = reshape_mvn.sample(sample_shape=[4, 5])
x.shape
# ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims]

reshape_mvn.log_prob(x).shape
# ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape
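The shapes above reflect one property: reshaping the batch only regroups batch members and does not change any density value. The sketch below checks this in plain NumPy, without TensorFlow. It assumes a hand-written diagonal-normal log-density, `mvn_diag_log_prob`, which is a hypothetical helper and not a TFP function. It compares two orders of operations. In the first, log_prob is evaluated on the original batch of 6 and the batch axis of the result is reshaped. In the second, the parameters and samples are reshaped first (keeping the last, event axis intact) and then evaluated.

```python
import numpy as np


def mvn_diag_log_prob(x, loc, scale):
    """Hypothetical helper: log-density of a diagonal-covariance normal.

    The last axis is the event axis; leading axes broadcast as sample/batch
    axes, so the result drops the event axis.
    """
    z = (x - loc) / scale
    per_dim = -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)
    return per_dim.sum(axis=-1)


dims = 2
old_batch_shape = (6,)
new_batch_shape = (1, 2, 3)  # the resolved form of [1, 2, -1]
sample_shape = (4, 5)

rng = np.random.default_rng(0)
loc = rng.normal(size=old_batch_shape + (dims,))
scale = rng.uniform(0.5, 2.0, size=old_batch_shape + (dims,))
x = rng.normal(size=sample_shape + old_batch_shape + (dims,))

# Order 1: evaluate on the original batch, then reshape only the batch axes.
lp_old = mvn_diag_log_prob(x, loc, scale)  # shape (4, 5, 6)
lp_then_reshape = lp_old.reshape(sample_shape + new_batch_shape)

# Order 2: reshape parameters and samples (event axis untouched), then evaluate.
loc_new = loc.reshape(new_batch_shape + (dims,))
scale_new = scale.reshape(new_batch_shape + (dims,))
x_new = x.reshape(sample_shape + new_batch_shape + (dims,))
lp_reshaped = mvn_diag_log_prob(x_new, loc_new, scale_new)  # shape (4, 5, 1, 2, 3)

print(np.allclose(lp_reshaped, lp_then_reshape))  # True
```

Both orders agree because the parameters, the samples, and the log-probabilities all split the batch axis of size 6 into (1, 2, 3) in the same row-major order. The event axis never moves, which is why the documented shapes are sample_shape + new_batch_shape + [dims] for samples and sample_shape + new_batch_shape for log_prob.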