LostTech.TensorFlow : API Documentation

Type BatchReshape

Namespace tensorflow.contrib.distributions

Parent Distribution

Interfaces IBatchReshape

The Batch-Reshaping distribution.

This "meta-distribution" reshapes the batch dimensions of another distribution.

#### Examples
```python
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

dtype = np.float32
dims = 2
new_batch_shape = [1, 2, -1]
old_batch_shape = [6]

scale = np.ones(old_batch_shape + [dims], dtype)
mvn = tfd.MultivariateNormalDiag(scale_diag=scale)
reshape_mvn = tfd.BatchReshape(
    distribution=mvn,
    batch_shape=new_batch_shape,
    validate_args=True)

reshape_mvn.batch_shape  # ==> [1, 2, 3]

x = reshape_mvn.sample(sample_shape=[4, 5])
x.shape  # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims]

reshape_mvn.log_prob(x).shape  # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape
```
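In the example, the `-1` entry in `new_batch_shape` is inferred so that the total number of batch members stays the same (6 = 1 × 2 × 3). The following is a minimal stdlib sketch of that shape arithmetic, assuming reshape semantics like NumPy's. `resolve_batch_shape` is a hypothetical helper, not part of TensorFlow Probability or LostTech.TensorFlow, and the library's own validation (enabled by `validate_args=True`) may differ in detail and error messages.

```python
from math import prod
from typing import List, Sequence


def resolve_batch_shape(old_batch_shape: Sequence[int],
                        new_batch_shape: Sequence[int]) -> List[int]:
    """Hypothetical helper: infer a single -1 in new_batch_shape so that
    the number of batch members matches old_batch_shape."""
    total = prod(old_batch_shape)  # number of batch members before reshaping
    unknown = [i for i, d in enumerate(new_batch_shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in new_batch_shape if d != -1)
    resolved = list(new_batch_shape)
    if unknown:
        # The inferred dimension must divide the total evenly.
        if known == 0 or total % known != 0:
            raise ValueError("batch sizes are incompatible")
        resolved[unknown[0]] = total // known
    if prod(resolved) != total:
        raise ValueError("batch sizes are incompatible")
    return resolved


print(resolve_batch_shape([6], [1, 2, -1]))  # [1, 2, 3]
```

A request such as `resolve_batch_shape([6], [4, -1])` raises `ValueError`, because 6 batch members cannot be split into rows of 4.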

Properties

Public properties

object allow_nan_stats { get; }

object allow_nan_stats_dyn { get; }

TensorShape batch_shape { get; }

object batch_shape_dyn { get; }

object distribution { get; }

object distribution_dyn { get; }

object dtype { get; }

object dtype_dyn { get; }

TensorShape event_shape { get; }

object event_shape_dyn { get; }

string name { get; }

object name_dyn { get; }

IDictionary<object, object> parameters { get; }

object parameters_dyn { get; }

object PythonObject { get; }

object reparameterization_type { get; }

object reparameterization_type_dyn { get; }

object validate_args { get; }

object validate_args_dyn { get; }
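The `batch_shape` and `event_shape` properties together determine the result shapes seen in the example: a sample has shape `sample_shape + batch_shape + event_shape`, and `log_prob` drops the trailing event dimensions, returning one value per batch member. The following is a minimal sketch of that bookkeeping using plain Python lists. `output_shapes` is a hypothetical helper for illustration, not a member of `BatchReshape`.

```python
from typing import List, Sequence, Tuple


def output_shapes(sample_shape: Sequence[int],
                  batch_shape: Sequence[int],
                  event_shape: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shapes of sample() and log_prob() results."""
    sample = list(sample_shape) + list(batch_shape) + list(event_shape)
    # log_prob reduces over the event dimensions: one value per batch member.
    log_prob = list(sample_shape) + list(batch_shape)
    return sample, log_prob


# Values from the MultivariateNormalDiag example: batch [1, 2, 3], event [2].
x_shape, lp_shape = output_shapes([4, 5], [1, 2, 3], [2])
print(x_shape)   # [4, 5, 1, 2, 3, 2]
print(lp_shape)  # [4, 5, 1, 2, 3]
```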