LostTech.TensorFlow : API Documentation

Type ScaleTriL

Namespace tensorflow.contrib.distributions.bijectors

Parent Chain

Interfaces IScaleTriL

Transforms unconstrained vectors to TriL matrices with positive diagonal.

This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` followed by `tfb.TransformDiagonal`, and is provided mostly as a convenience. The default setup is somewhat opinionated: the diagonal is passed through a Softplus transformation and then shifted by a small constant (`1e-5`), which helps avoid numerical issues caused by zeros on the diagonal.
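The composition described above can be sketched in plain Python, without TensorFlow, to make the two steps concrete. This is only an illustrative sketch, not the library's implementation: the helper names `softplus`, `fill_triangular`, and `scale_tril_forward` are hypothetical, and the element ordering used when packing the vector is an assumption chosen to agree with the `inverse` result shown under Examples.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def fill_triangular(x):
    # Pack a vector of length n*(n+1)/2 into an n x n lower-triangular matrix.
    # Assumed ordering: concatenate x[n:] with reversed(x), reshape to n x n,
    # then zero out everything above the diagonal.
    m = len(x)
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError("input length must be a triangular number")
    flat = list(x[n:]) + list(reversed(x))
    rows = [flat[i * n:(i + 1) * n] for i in range(n)]
    return [[rows[i][j] if j <= i else 0.0 for j in range(n)] for i in range(n)]


def scale_tril_forward(x, diag_fn=softplus, diag_shift=1e-5):
    # Step 1: fill the lower triangle from the unconstrained vector.
    tril = fill_triangular(x)
    # Step 2: transform the diagonal, then apply the optional shift.
    for i in range(len(tril)):
        tril[i][i] = diag_fn(tril[i][i]) + (diag_shift or 0.0)
    return tril


identity_like = scale_tril_forward([0.0, 0.0, 0.0], diag_fn=math.exp, diag_shift=None)
default_diag = scale_tril_forward([0.0, 0.0, 0.0])[0][0]
```

With `diag_fn=math.exp` and no shift, a zero vector maps to the 2x2 identity, which mirrors the first `forward` call under Examples. With the defaults, each diagonal entry at zero input is `log(2) + 1e-5`.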

#### Examples
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions
    tfb = tfp.bijectors

    b = tfb.ScaleTriL(
        diag_bijector=tfb.Exp(),
        diag_shift=None)
    b.forward(x=[0., 0., 0.])
    # Result: [[1., 0.],
    #          [0., 1.]]
    b.inverse(y=[[1., 0],
                 [.5, 2]])
    # Result: [log(2), .5, log(1)]

    # Define a distribution over PSD matrices of shape `[3, 3]`,
    # with `1 + 2 + 3 = 6` degrees of freedom.
    dist = tfd.TransformedDistribution(
        tfd.Normal(tf.zeros(6), tf.ones(6)),
        tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()]))

    # Using an identity transformation, ScaleTriL is equivalent to
    # tfb.FillTriangular.
    b = tfb.ScaleTriL(
        diag_bijector=tfb.Identity(),
        diag_shift=None)

    # For greater control over initialization, one can manually encode
    # pre- and post- shifts inside of `diag_bijector`.
    b = tfb.ScaleTriL(
        diag_bijector=tfb.Chain([
            tfb.AffineScalar(shift=1e-3),
            tfb.Softplus(),
            tfb.AffineScalar(shift=0.5413)]),  # softplus_inverse(1.)
                                               #  = log(expm1(1.)) = 0.5413
        diag_shift=None)
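The constant `0.5413` in the last example can be checked with the standard library alone. The sketch below assumes the chain is applied right to left (pre-shift, then Softplus, then post-shift); `softplus`, `softplus_inverse`, and `diag_transform` are hypothetical helper names used only for illustration.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def softplus_inverse(y):
    # Inverse of softplus for y > 0: log(exp(y) - 1).
    return math.log(math.expm1(y))


# Pre-shift chosen so that softplus(0 + pre_shift) is 1.
pre_shift = softplus_inverse(1.0)


def diag_transform(v, pre_shift=pre_shift, post_shift=1e-3):
    # Pre-shift, then softplus, then a small positive post-shift.
    return softplus(v + pre_shift) + post_shift


at_zero = diag_transform(0.0)
```

Under these assumptions, an unconstrained value of zero maps to a diagonal entry of about `1.001`, so a zero-initialized vector starts close to the identity matrix.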

Properties

Public properties

object bijectors get;

object bijectors_dyn get;

object dtype get;

object dtype_dyn get;

object forward_min_event_ndims get;

object forward_min_event_ndims_dyn get;

IList<object> graph_parents get;

object graph_parents_dyn get;

object inverse_min_event_ndims get;

object inverse_min_event_ndims_dyn get;

bool is_constant_jacobian get;

object is_constant_jacobian_dyn get;

object name get;

object name_dyn get;

object PythonObject get;

bool validate_args get;

object validate_args_dyn get;