Type BatchReshape
Namespace tensorflow.contrib.distributions
Parent Distribution
Interfaces IBatchReshape
The Batch-Reshaping distribution.

This "meta-distribution" reshapes the batch dimensions of another distribution, leaving its event shape unchanged.

#### Examples
```python
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

dtype = np.float32
dims = 2
new_batch_shape = [1, 2, -1]
old_batch_shape = [6]

# Six independent 2-D MVNs: batch_shape [6], event_shape [2].
scale = np.ones(old_batch_shape + [dims], dtype)
mvn = tfd.MultivariateNormalDiag(scale_diag=scale)

# Reshape the batch dimensions only; the -1 is inferred as 3.
reshape_mvn = tfd.BatchReshape(
    distribution=mvn,
    batch_shape=new_batch_shape,
    validate_args=True)

reshape_mvn.batch_shape
# ==> [1, 2, 3]

x = reshape_mvn.sample(sample_shape=[4, 5])
x.shape
# ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims]

reshape_mvn.log_prob(x).shape
# ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape
```
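To make the shape bookkeeping concrete without depending on TensorFlow, here is a minimal NumPy-only sketch. It is not the library's implementation. `resolve_batch_shape` and `mvn_diag_log_prob` are hypothetical helpers written for illustration. The sketch resolves the `-1` in the new batch shape the way `numpy.reshape` does, then checks that evaluating a diagonal-Normal log density on batch-reshaped samples and parameters gives the same values as evaluating it on the original `[6]` batch layout and reshaping the result. Only the batch axes are reshaped; the trailing event axis (`dims`) is left alone.

```python
import numpy as np


def resolve_batch_shape(old_batch_shape, new_batch_shape):
    """Hypothetical helper: resolve at most one -1 entry, like numpy.reshape."""
    old_size = int(np.prod(old_batch_shape))
    new = list(new_batch_shape)
    if new.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in new:
        known = int(np.prod([d for d in new if d != -1]))
        if known == 0 or old_size % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        new[new.index(-1)] = old_size // known
    if int(np.prod(new)) != old_size:
        raise ValueError("new batch shape must have the same number of elements")
    return new


def mvn_diag_log_prob(x, loc, scale_diag):
    """Hypothetical helper: log N(x; loc, diag(scale_diag**2)), summed over the event axis."""
    z = (x - loc) / scale_diag
    return np.sum(-0.5 * z**2 - np.log(scale_diag) - 0.5 * np.log(2 * np.pi), axis=-1)


rng = np.random.default_rng(0)
dims = 2
old_batch_shape = [6]
new_batch_shape = resolve_batch_shape(old_batch_shape, [1, 2, -1])  # [1, 2, 3]
sample_shape = [4, 5]

loc = np.zeros(old_batch_shape + [dims])
scale = np.ones(old_batch_shape + [dims])

# Samples laid out against the original batch: shape [4, 5, 6, 2].
x_old = loc + scale * rng.standard_normal(sample_shape + old_batch_shape + [dims])

# Reshape only the batch axes; the event axis stays last: [4, 5, 1, 2, 3, 2].
x_new = x_old.reshape(sample_shape + new_batch_shape + [dims])

# Log density with parameters reshaped the same way: [4, 5, 1, 2, 3].
lp_new = mvn_diag_log_prob(
    x_new,
    loc.reshape(new_batch_shape + [dims]),
    scale.reshape(new_batch_shape + [dims]))

# Log density on the original layout, then reshaped: also [4, 5, 1, 2, 3].
lp_old = mvn_diag_log_prob(x_old, loc, scale).reshape(sample_shape + new_batch_shape)
```

Because both reshapes use the same row-major order, batch member `k` of the original layout lands at index `(0, k // 3, k % 3)` in the new one, so the two log-density arrays agree element by element.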