Type OneHotCategorical
Namespace tensorflow.contrib.distributions
Parent Distribution
Interfaces IOneHotCategorical
OneHotCategorical distribution. The categorical distribution is parameterized by either the log-probabilities (`logits`) or the probabilities (`probs`)
of a set of classes. The difference between OneHotCategorical and Categorical
distributions is that OneHotCategorical is a discrete distribution over
one-hot bit vectors, whereas Categorical is a discrete distribution over
non-negative integer class indices. OneHotCategorical is equivalent to Categorical except that
Categorical has `event_shape=()` while OneHotCategorical has `event_shape=[K]`, where
K is the number of classes.

This class provides methods to create indexed batches of OneHotCategorical
distributions. If the provided `logits` or `probs` is rank 2 or higher, then for
every fixed set of leading dimensions, the last dimension represents one
single OneHotCategorical distribution. When calling distribution
functions (e.g. `dist.prob(x)`), `logits` and `x` are broadcast to the
same shape (if possible). In all cases, the last dimension of `logits` and `x`
indexes the K classes of a single OneHotCategorical distribution.

#### Examples

Creates a 3-class distribution, with the 2nd class the most likely to be
drawn from.

    from tensorflow.contrib.distributions import OneHotCategorical  # TF 1.x

    # Class probabilities; the 2nd class (index 1) is the most likely.
    p = [0.1, 0.5, 0.4]
    dist = OneHotCategorical(probs=p)

Creates a 3-class distribution, with the 2nd class the most likely to be
drawn from, using logits.

Creates a 3-class distribution, with the 3rd class the most likely to be drawn.
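The sketch below illustrates the semantics described above in pure Python. It is not the TensorFlow implementation. It converts a Categorical sample (a class index, event shape `()`) to and from a OneHotCategorical event (a one-hot vector, event shape `[K]`). It also evaluates the probability of one-hot samples by reusing a single probability vector for every sample in a batch, which is what broadcasting `probs` against `x` amounts to in this simple case. The helper names `to_one_hot`, `from_one_hot` and `one_hot_prob` are hypothetical.

```python
from typing import List, Sequence


def to_one_hot(index: int, num_classes: int) -> List[int]:
    """Map a Categorical sample (class index) to a OneHotCategorical event."""
    if not 0 <= index < num_classes:
        raise ValueError("index out of range")
    return [1 if k == index else 0 for k in range(num_classes)]


def from_one_hot(event: Sequence[int]) -> int:
    """Map a one-hot event of length K back to its class index."""
    if sorted(event) != [0] * (len(event) - 1) + [1]:
        raise ValueError("not a one-hot vector")
    return list(event).index(1)


def one_hot_prob(probs: Sequence[float], x):
    """Probability of one-hot event(s) x under a single distribution `probs`.

    `x` is either one event of length K (returns a scalar) or a list of
    events (returns a list). The same `probs` is used for every event,
    mimicking the broadcast of `probs` against a batch of samples.
    """
    if x and isinstance(x[0], (list, tuple)):
        return [one_hot_prob(probs, event) for event in x]
    if len(x) != len(probs):
        raise ValueError("last dimensions of probs and x must match")
    # Only the "hot" coordinate contributes: sum_k probs[k] * x[k].
    return sum(p * xi for p, xi in zip(probs, x))
```

With the probabilities from the example above, `one_hot_prob([0.1, 0.5, 0.4], [0, 1, 0])` returns `0.5`. Passing a batch `[[0, 1, 0], [1, 0, 0]]` instead returns one probability per sample, `[0.5, 0.1]`.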
Properties
- allow_nan_stats
- allow_nan_stats_dyn
- batch_shape
- batch_shape_dyn
- dtype
- dtype_dyn
- event_shape
- event_shape_dyn
- event_size
- event_size_dyn
- logits
- logits_dyn
- name
- name_dyn
- parameters
- parameters_dyn
- probs
- probs_dyn
- PythonObject
- reparameterization_type
- reparameterization_type_dyn
- validate_args
- validate_args_dyn
Public properties
object allow_nan_stats { get; }
object allow_nan_stats_dyn { get; }
TensorShape batch_shape { get; }
object batch_shape_dyn { get; }
object dtype { get; }
object dtype_dyn { get; }
TensorShape event_shape { get; }
object event_shape_dyn { get; }
Tensor event_size { get; }
Scalar `int32` tensor: the number of classes.
object event_size_dyn { get; }
Scalar `int32` tensor: the number of classes.
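The last dimension of `probs` (or `logits`) indexes the classes. So `event_size` is the length of that dimension, `event_shape` is `[K]`, and any leading dimensions form `batch_shape`. The pure-Python sketch below shows that bookkeeping on nested lists. `shape_of` and `split_shapes` are hypothetical helpers, not part of the API, and the sketch assumes rectangular input.

```python
from typing import List, Tuple


def shape_of(nested) -> List[int]:
    """Shape of a rectangular nested list, e.g. a B x K list -> [B, K]."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return shape


def split_shapes(probs) -> Tuple[List[int], List[int], int]:
    """Return (batch_shape, event_shape, event_size) for OneHotCategorical.

    The last dimension holds the K classes of one distribution; all
    leading dimensions index a batch of independent distributions.
    """
    shape = shape_of(probs)
    if not shape:
        raise ValueError("probs must have rank >= 1")
    k = shape[-1]
    return shape[:-1], [k], k
```

For a single vector such as `[0.1, 0.5, 0.4]` the batch shape is empty and `event_size` is 3. A 2 x 3 nested list describes a batch of two 3-class distributions.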
object logits { get; }
Vector of coordinatewise logits.
object logits_dyn { get; }
Vector of coordinatewise logits.
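The logits map to class probabilities through a softmax over the last dimension. Below is a minimal sketch for one unbatched logits vector, using plain Python lists rather than tensors. `softmax` here is a local helper, not the TensorFlow op. The logits values in the usage note are arbitrary illustrations.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn one vector of logits into class probabilities.

    Subtracting the maximum before exp() avoids overflow. It does not
    change the result, because softmax ignores a constant added to every
    logit.
    """
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `softmax([-1.0, 1.0, 0.0])` gives probabilities that sum to 1 with the 2nd class the most likely. Adding the same constant to every logit leaves the probabilities unchanged.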
string name { get; }
object name_dyn { get; }
IDictionary<object, object> parameters { get; }
object parameters_dyn { get; }
Tensor probs { get; }
Vector of coordinatewise probabilities.
object probs_dyn { get; }
Vector of coordinatewise probabilities.
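Our understanding is that both `logits` and `probs` are available whichever one was passed to the constructor, with the missing one derived from the other. The sketch below assumes that relationship. When starting from probabilities, the element-wise log is one valid choice of logits, and applying softmax to it recovers the original probabilities. `probs_to_logits` is a hypothetical helper, and the sketch assumes strictly positive probabilities that sum to 1.

```python
import math
from typing import List, Sequence


def probs_to_logits(probs: Sequence[float]) -> List[float]:
    """Element-wise log: one valid set of logits for these probabilities."""
    if any(p <= 0 for p in probs):
        raise ValueError("this sketch needs strictly positive probabilities")
    return [math.log(p) for p in probs]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Round trip with the probabilities from the example above.
p = [0.1, 0.5, 0.4]
recovered = softmax(probs_to_logits(p))
```

The logits are only determined up to an additive constant, so other logits vectors also map back to `p`. The log of the probabilities is simply the most direct choice.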