LostTech.TensorFlow : API Documentation

Type MaskedBasicLSTMCell

Namespace tensorflow.contrib.model_pruning

Parent BasicLSTMCell

Interfaces IMaskedBasicLSTMCell

Basic LSTM recurrent network cell with pruning.

Overrides the `call` method of TensorFlow's BasicLSTMCell and injects the weight masks.

The implementation is based on: http://arxiv.org/abs/1409.2329.

We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting at the beginning of training.

It does not allow cell clipping or a projection layer, and does not use peephole connections: it is the basic baseline.

For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.
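The mask injection described above can be sketched in NumPy: an elementwise 0/1 mask is applied to the cell's kernel before the gate matmul, and forget_bias is added to the forget gate. All names below are illustrative, not part of the LostTech.TensorFlow API; the gate ordering (i, j, f, o) follows TensorFlow's BasicLSTMCell.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def masked_lstm_step(x, c, h, kernel, bias, mask, forget_bias=1.0):
    """One step of a basic LSTM cell with a pruning mask on the kernel.

    mask has the same shape as kernel, with entries in {0, 1}; pruned
    weights are zeroed before the matmul. forget_bias is added to the
    forget gate, as in http://arxiv.org/abs/1409.2329.
    """
    masked_kernel = kernel * mask  # inject the weight mask
    gates = np.concatenate([x, h], axis=1) @ masked_kernel + bias
    # input gate, candidate, forget gate, output gate
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h

# Toy dimensions: batch=2, input size=3, num_units=4.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))
c = np.zeros((2, 4))
h = np.zeros((2, 4))
kernel = rng.normal(size=(3 + 4, 4 * 4))
bias = np.zeros(4 * 4)
mask = (rng.random(kernel.shape) > 0.5).astype(float)  # prune ~50% of weights
new_c, new_h = masked_lstm_step(x, c, h, kernel, bias, mask)
```

In the real cell the mask is a non-trainable variable updated by the pruning schedule; here it is fixed only to show where it enters the computation.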


Public static methods

MaskedBasicLSTMCell NewDyn(object num_units, ImplicitContainer<T> forget_bias, ImplicitContainer<T> state_is_tuple, object activation, object reuse, object name)

Initialize the basic LSTM cell with pruning.
Parameters
object num_units
int, The number of units in the LSTM cell.
ImplicitContainer<T> forget_bias
float, The bias added to forget gates (see above). Must be set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
ImplicitContainer<T> state_is_tuple
If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated.
object activation
Activation function of the inner states. Default: `tanh`.
object reuse
(optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised.
object name
String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.

When restoring from CudnnLSTM-trained checkpoints, you must use CudnnCompatibleLSTMCell instead.
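The two state layouts selected by state_is_tuple can be sketched in NumPy (variable names here are illustrative): a 2-tuple of `c_state` and `m_state`, versus a single array concatenated along the column axis.

```python
import numpy as np

batch, num_units = 2, 4
c_state = np.ones((batch, num_units))   # cell state
m_state = np.zeros((batch, num_units))  # output (hidden) state

# state_is_tuple=True: states are 2-tuples of (c_state, m_state)
state_tuple = (c_state, m_state)

# state_is_tuple=False (soon deprecated): states are concatenated
# along the column axis into one [batch, 2 * num_units] array
state_concat = np.concatenate([c_state, m_state], axis=1)
```

The tuple form avoids the split/concat overhead on every step, which is why the concatenated layout is deprecated.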

Public properties

PythonFunctionContainer activity_regularizer get; set;

object activity_regularizer_dyn get; set;

bool built get; set;

object dtype get;

object dtype_dyn get;

bool dynamic get;

object dynamic_dyn get;

object graph get;

object graph_dyn get;

IList<Node> inbound_nodes get;

object inbound_nodes_dyn get;

IList<object> input get;

object input_dyn get;

object input_mask get;

object input_mask_dyn get;

IList<object> input_shape get;

object input_shape_dyn get;

InputSpec input_spec get; set;

object input_spec_dyn get; set;

IList<object> losses get;

object losses_dyn get;

IList<object> metrics get;

object metrics_dyn get;

object name get;

object name_dyn get;

object name_scope get;

object name_scope_dyn get;

IList<object> non_trainable_variables get;

object non_trainable_variables_dyn get;

IList<object> non_trainable_weights get;

object non_trainable_weights_dyn get;

IList<object> outbound_nodes get;

object outbound_nodes_dyn get;

IList<object> output get;

object output_dyn get;

object output_mask get;

object output_mask_dyn get;

object output_shape get;

object output_shape_dyn get;

object output_size get;

object output_size_dyn get;

object PythonObject get;

object rnncell_scope get; set;

string scope_name get;

object scope_name_dyn get;

object state_size get;

object state_size_dyn get;

bool stateful get; set;

ValueTuple<object> submodules get;

object submodules_dyn get;

bool supports_masking get; set;

bool trainable get; set;

object trainable_dyn get; set;

object trainable_variables get;

object trainable_variables_dyn get;

IList<object> trainable_weights get;

object trainable_weights_dyn get;

IList<object> updates get;

object updates_dyn get;

object variables get;

object variables_dyn get;

IList<object> weights get;

object weights_dyn get;