Type AttentionWrapper
Namespace tensorflow.contrib.seq2seq
Parent RNNCell
Interfaces IAttentionWrapper
Wraps another `RNNCell` with attention.
Methods
Properties
- activity_regularizer
- activity_regularizer_dyn
- built
- dtype
- dtype_dyn
- dynamic
- dynamic_dyn
- graph
- graph_dyn
- inbound_nodes
- inbound_nodes_dyn
- input
- input_dyn
- input_mask
- input_mask_dyn
- input_shape
- input_shape_dyn
- input_spec
- input_spec_dyn
- losses
- losses_dyn
- metrics
- metrics_dyn
- name
- name_dyn
- name_scope
- name_scope_dyn
- non_trainable_variables
- non_trainable_variables_dyn
- non_trainable_weights
- non_trainable_weights_dyn
- outbound_nodes
- outbound_nodes_dyn
- output
- output_dyn
- output_mask
- output_mask_dyn
- output_shape
- output_shape_dyn
- output_size
- output_size_dyn
- PythonObject
- rnncell_scope
- scope_name
- scope_name_dyn
- state_size
- state_size_dyn
- stateful
- submodules
- submodules_dyn
- supports_masking
- trainable
- trainable_dyn
- trainable_variables
- trainable_variables_dyn
- trainable_weights
- trainable_weights_dyn
- updates
- updates_dyn
- variables
- variables_dyn
- weights
- weights_dyn
Public static methods
AttentionWrapper NewDyn(object cell, object attention_mechanism, object attention_layer_size, ImplicitContainer<T> alignment_history, object cell_input_fn, ImplicitContainer<T> output_attention, object initial_cell_state, object name, object attention_layer, object attention_fn, object dtype)
Construct the `AttentionWrapper`.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a `cell_state` value containing properly tiled final state from the encoder.
An example:
```
# Tile every encoder-side tensor so each batch entry is repeated beam_width times.
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
    sequence_length, multiplier=beam_width)
# The attention memory must be the tiled encoder outputs.
attention_mechanism = MyFavoriteAttentionMechanism(
    num_units=attention_depth,
    memory=tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
# zero_state takes the tiled batch size; clone swaps in the tiled encoder state.
decoder_initial_state = attention_cell.zero_state(
    dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
```
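The difference between the two tiling orders in the note above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library implementation: `tile_batch_sketch` and `tile_sketch` are hypothetical helpers that only mimic the row ordering of `tf.contrib.seq2seq.tile_batch` and of `tf.tile` along the batch axis, and the rule "row `i` belongs to source entry `i // beam_width`" is an assumption about how the beam search decoder groups the tiled rows.

```python
def tile_batch_sketch(batch, multiplier):
    """Repeat each batch entry `multiplier` times in a row: a, a, a, b, b, b."""
    return [entry for entry in batch for _ in range(multiplier)]


def tile_sketch(batch, multiplier):
    """Repeat the whole batch `multiplier` times: a, b, a, b, a, b."""
    return list(batch) * multiplier


true_batch_size, beam_width = 2, 3
encoder_outputs = ["enc_a", "enc_b"]  # one placeholder per batch entry

tiled_ok = tile_batch_sketch(encoder_outputs, beam_width)
tiled_wrong = tile_sketch(encoder_outputs, beam_width)

# Assumption: the decoder treats row i as a beam of source entry i // beam_width.
ok_matches = [tiled_ok[i] == encoder_outputs[i // beam_width]
              for i in range(true_batch_size * beam_width)]
wrong_matches = [tiled_wrong[i] == encoder_outputs[i // beam_width]
                 for i in range(true_batch_size * beam_width)]

print(tiled_ok)
print(tiled_wrong)
print(all(ok_matches), all(wrong_matches))
```

Under that grouping assumption, only the `tile_batch` ordering keeps every beam aligned with the encoder output of its own source sentence, which is why the note warns against `tf.tile`.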
Parameters
-
object
cell - An instance of `RNNCell`.
-
object
attention_mechanism - A list of `AttentionMechanism` instances or a single instance.
-
object
attention_layer_size - A list of Python integers or a single Python integer, the depth of the attention (output) layer(s). If None (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If attention_layer is set, this must be None. If attention_fn is set, it must be guaranteed that the outputs of attention_fn also meet the above requirements.
-
ImplicitContainer<T>
alignment_history - Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time-major `TensorArray` on which you must call `stack()`).
-
object
cell_input_fn - (optional) A `callable`. The default is: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
-
ImplicitContainer<T>
output_attention - Python bool. If `True` (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If `False`, the output at each time step is the output of `cell`. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the `attention` tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output.
-
object
initial_cell_state - The initial state value to use for the cell when the user calls `zero_state()`. Note that if this value is provided now, and the user uses a `batch_size` argument of `zero_state` which does not match the batch size of `initial_cell_state`, proper behavior is not guaranteed.
-
object
name - Name to use when creating ops.
-
object
attention_layer - A list of `tf.compat.v1.layers.Layer` instances or a single `tf.compat.v1.layers.Layer` instance taking the context and cell output as inputs to generate attention at each time step. If None (default), use the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layer_size is set, this must be None.
-
object
attention_fn - An optional callable function that allows users to provide their own customized attention function, which takes input (attention_mechanism, cell_output, attention_state, attention_layer) and outputs (attention, alignments, next_attention_state). If provided, the attention_layer_size should be the size of the outputs of attention_fn.
-
object
dtype - The cell dtype
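To make the interaction between `cell_input_fn` and `output_attention` more concrete, here is a minimal pure-Python sketch of one wrapper time step for a single example (no batch dimension). It is an illustration under stated assumptions, not the TensorFlow implementation: `toy_cell`, `dot_attention`, and `wrapper_step` are hypothetical stand-ins, the attention is a plain dot-product score followed by a softmax, and `attention_layer_size` is treated as None so the context vector is used directly as the attention.

```python
import math


def concat_inputs(inputs, attention):
    # Mirrors the documented default cell_input_fn: concatenate on the last axis.
    return inputs + attention


def toy_cell(cell_inputs, state):
    # Hypothetical stand-in for the wrapped RNNCell; returns (output, next_state).
    total = sum(cell_inputs)
    output = [math.tanh(total + s) for s in state]
    return output, output


def dot_attention(query, memory):
    # Assumed scoring: dot product of the cell output with each memory row.
    scores = [sum(q * m for q, m in zip(query, row)) for row in memory]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # numerically stable softmax
    norm = sum(exps)
    alignments = [e / norm for e in exps]
    depth = len(memory[0])
    # Context is the alignment-weighted sum of the memory rows.
    context = [sum(a * row[d] for a, row in zip(alignments, memory))
               for d in range(depth)]
    return context, alignments


def wrapper_step(inputs, prev_attention, cell_state, memory,
                 output_attention=True, cell_input_fn=concat_inputs):
    cell_inputs = cell_input_fn(inputs, prev_attention)
    cell_output, next_cell_state = toy_cell(cell_inputs, cell_state)
    attention, alignments = dot_attention(cell_output, memory)
    # The flag only selects what is emitted; attention is always kept in the state.
    output = attention if output_attention else cell_output
    return output, (next_cell_state, attention, alignments)


memory = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # three encoder steps, depth 2
inputs = [0.1, -0.2, 0.3]
prev_attention = [0.0, 0.0]
cell_state = [0.0, 0.5]

luong_out, luong_state = wrapper_step(
    inputs, prev_attention, cell_state, memory, output_attention=True)
bahdanau_out, bahdanau_state = wrapper_step(
    inputs, prev_attention, cell_state, memory, output_attention=False)
```

In this sketch both calls produce the same state tuple. Only the returned output differs: the attention vector when `output_attention=True`, the cell output when it is `False`.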
Public properties
PythonFunctionContainer activity_regularizer get; set;
object activity_regularizer_dyn get; set;
bool built get; set;
object dtype get;
object dtype_dyn get;
bool dynamic get;
object dynamic_dyn get;
object graph get;
object graph_dyn get;
IList<Node> inbound_nodes get;
object inbound_nodes_dyn get;
IList<object> input get;
object input_dyn get;
object input_mask get;
object input_mask_dyn get;
IList<object> input_shape get;
object input_shape_dyn get;
object input_spec get; set;
object input_spec_dyn get; set;
IList<object> losses get;
object losses_dyn get;
IList<object> metrics get;
object metrics_dyn get;
object name get;
object name_dyn get;
object name_scope get;
object name_scope_dyn get;
IList<object> non_trainable_variables get;
object non_trainable_variables_dyn get;
IList<object> non_trainable_weights get;
object non_trainable_weights_dyn get;
IList<object> outbound_nodes get;
object outbound_nodes_dyn get;
IList<object> output get;
object output_dyn get;
object output_mask get;
object output_mask_dyn get;
object output_shape get;
object output_shape_dyn get;
object output_size get;
Integer or TensorShape: size of outputs produced by this cell.
object output_size_dyn get;
Integer or TensorShape: size of outputs produced by this cell.
object PythonObject get;
object rnncell_scope get; set;
string scope_name get;
object scope_name_dyn get;
object state_size get;
The `state_size` property of `AttentionWrapper`.
object state_size_dyn get;
The `state_size` property of `AttentionWrapper`.