Type BeamSearchDecoder
Namespace tensorflow.contrib.seq2seq
Parent Decoder
Interfaces IBeamSearchDecoder
BeamSearch sampling decoder.

**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`, then you must ensure that:

- The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a `cell_state` value containing the properly tiled final state from the encoder.

An example:

```
# Tile encoder outputs, final state, and sequence lengths to beam_width.
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
    sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
    num_units=attention_depth,
    memory=tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
# zero_state must be called with batch_size = true_batch_size * beam_width.
decoder_initial_state = attention_cell.zero_state(
    batch_size=true_batch_size * beam_width, dtype=dtype)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
```

Meanwhile, with `AttentionWrapper`, it is recommended to use a coverage penalty when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages the decoder to cover all inputs.
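For instance, the attention example above can be extended to build the decoder itself with a non-zero coverage penalty. This is a minimal sketch against the underlying `tf.contrib.seq2seq` Python API that this type wraps; `embedding_matrix`, `start_tokens`, `end_token`, `vocab_size`, and the penalty weights are assumed placeholders from the surrounding model.

```
output_layer = tf.layers.Dense(vocab_size, use_bias=False)

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=attention_cell,
    embedding=embedding_matrix,            # assumed [vocab_size, embed_dim] variable
    start_tokens=start_tokens,             # assumed int32 [true_batch_size] tensor
    end_token=end_token,                   # assumed int32 scalar
    initial_state=decoder_initial_state,   # tiled state built above
    beam_width=beam_width,
    output_layer=output_layer,
    length_penalty_weight=0.6,             # assumed value; 0.0 disables
    coverage_penalty_weight=0.2)           # non-zero value enables the coverage penalty
```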
Public static methods
BeamSearchDecoder NewDyn(object cell, object embedding, object start_tokens, object end_token, object initial_state, object beam_width, object output_layer, ImplicitContainer<T> length_penalty_weight, ImplicitContainer<T> coverage_penalty_weight, ImplicitContainer<T> reorder_tensor_arrays)
Initialize the BeamSearchDecoder. A usage sketch follows the parameter list below.
Parameters
- object cell - An `RNNCell` instance.
- object embedding - A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`.
- object start_tokens - `int32` vector shaped `[batch_size]`, the start tokens.
- object end_token - `int32` scalar, the token that marks the end of decoding.
- object initial_state - A (possibly nested tuple of...) tensors and TensorArrays.
- object beam_width - Python integer, the number of beams.
- object output_layer - (Optional) An instance of `tf.keras.layers.Layer`, e.g. `tf.keras.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling.
- ImplicitContainer&lt;T&gt; length_penalty_weight - Float weight to penalize length. Disabled with 0.0.
- ImplicitContainer&lt;T&gt; coverage_penalty_weight - Float weight to penalize the coverage of the source sentence. Disabled with 0.0.
- ImplicitContainer&lt;T&gt; reorder_tensor_arrays - If `True`, elements of the `TensorArray`s within the cell state will be reordered according to the beam search path. If a `TensorArray` can be reordered, its stacked form will be returned. Otherwise, the `TensorArray` will be returned as is. Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering.
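As a rough orientation for how these parameters are used at decode time, the sketch below drives the constructed decoder with `tf.contrib.seq2seq.dynamic_decode` in the underlying Python API; `decoder` and `maximum_iterations` are assumed to come from the surrounding model.

```
# Run the beam search; impute_finished must stay False for BeamSearchDecoder.
outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=maximum_iterations)

# outputs is a FinalBeamSearchDecoderOutput; predicted_ids has shape
# [batch_size, max_time, beam_width].
predicted_ids = outputs.predicted_ids
```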
Public properties
object batch_size get;
object batch_size_dyn get;
object output_dtype get;
A (possibly nested tuple of...) dtype[s].
object output_dtype_dyn get;
A (possibly nested tuple of...) dtype[s].