Type Strategy
Namespace tensorflow.distribute
Parent Strategy
Interfaces IStrategy
A list of devices with a state & compute distribution policy. See [the guide](https://www.tensorflow.org/guide/distribute_strategy) for an overview and examples.

Note: Not all `tf.distribute.Strategy` implementations currently support TensorFlow's partitioned variables (where a single variable is split across multiple devices).
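As a minimal sketch of typical usage on the Python side that these bindings mirror (the choice of `MirroredStrategy` and the Keras model are illustrative assumptions, not part of this page):

```
import tensorflow as tf

# Pick a concrete strategy; MirroredStrategy replicates across local GPUs.
strategy = tf.distribute.MirroredStrategy()

# Variables created under the scope are distributed per the strategy's policy.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='sgd', loss='mse')
```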
Methods
- experimental_run_v2
- make_dataset_iterator
- make_dataset_iterator_dyn
- make_input_fn_iterator
- make_input_fn_iterator_dyn
Properties
Public instance methods
object experimental_run_v2(object fn, ImplicitContainer<T> args, IDictionary<string, object> kwargs)
See base class.
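A minimal sketch of the per-replica call this forwards to, assuming the Python-side semantics of `Strategy.experimental_run_v2` (run `fn` once on each replica; the constant input is illustrative):

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
  # Executed once per replica; x is that replica's input.
  return x * 2.0

# args is packed as a tuple; the result is a per-replica value.
per_replica = strategy.experimental_run_v2(replica_fn, args=(tf.constant(3.0),))
```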
object make_dataset_iterator(object dataset)
Makes an iterator for input provided via `dataset`.

DEPRECATED: This method is not available in TF 2.x.

Data from the given dataset will be distributed evenly across all the compute replicas. The input dataset is assumed to be batched by the global batch size; under this assumption, a best effort is made to divide each batch across all the replicas (one or more workers). If this fails, an error is thrown; use `make_input_fn_iterator` instead, which gives more control and does not try to divide a batch across replicas. `make_input_fn_iterator` is also the right choice when you want to customize which input is fed to which replica/worker.
Parameters
- object dataset - A `tf.data.Dataset` that will be distributed evenly across all replicas.
Returns
- object - A `tf.distribute.InputIterator` which returns inputs for each step of the computation. The user should call `initialize` on the returned iterator.
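A hedged graph-mode sketch of the pattern described above (`MirroredStrategy`, `experimental_run`, `experimental_local_results`, and the v1 `Session` are assumptions drawn from the TF 1.x Python API this page mirrors):

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 16

# Batch by the *global* batch size; the strategy splits each batch
# evenly across the compute replicas.
dataset = tf.data.Dataset.from_tensors([1.]).repeat().batch(global_batch_size)
iterator = strategy.make_dataset_iterator(dataset)

def step_fn(inputs):
  return tf.reduce_sum(inputs)  # runs once per replica on its slice

per_replica_sum = strategy.experimental_run(step_fn, iterator)

with tf.compat.v1.Session() as sess:
  sess.run(iterator.initialize())  # initialize before the first step
  sess.run(strategy.experimental_local_results(per_replica_sum))
```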
object make_dataset_iterator_dyn(object dataset)
Makes an iterator for input provided via `dataset`.

DEPRECATED: This method is not available in TF 2.x.

Data from the given dataset will be distributed evenly across all the compute replicas. The input dataset is assumed to be batched by the global batch size; under this assumption, a best effort is made to divide each batch across all the replicas (one or more workers). If this fails, an error is thrown; use `make_input_fn_iterator` instead, which gives more control and does not try to divide a batch across replicas. `make_input_fn_iterator` is also the right choice when you want to customize which input is fed to which replica/worker.
Parameters
- object dataset - A `tf.data.Dataset` that will be distributed evenly across all replicas.
Returns
- object - A `tf.distribute.InputIterator` which returns inputs for each step of the computation. The user should call `initialize` on the returned iterator.
object make_input_fn_iterator(PythonFunctionContainer input_fn, ImplicitContainer<T> replication_mode)
Returns an iterator split across replicas created from an input function.

DEPRECATED: This method is not available in TF 2.x.

The `input_fn` should take a `tf.distribute.InputContext` object, through which information about batching and input sharding can be accessed:
```
def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

with strategy.scope():
  iterator = strategy.make_input_fn_iterator(input_fn)
  replica_results = strategy.experimental_run(replica_fn, iterator)
```
The `tf.data.Dataset` returned by `input_fn` should have a per-replica batch size, which may be computed using `input_context.get_per_replica_batch_size`.
Parameters
- PythonFunctionContainer input_fn - A function taking a `tf.distribute.InputContext` object and returning a `tf.data.Dataset`.
- ImplicitContainer&lt;T&gt; replication_mode - An enum value of `tf.distribute.InputReplicationMode`. Only `PER_WORKER` is supported currently, which means there will be a single call to `input_fn` per worker. Replicas will dequeue from the local `tf.data.Dataset` on their worker.
Returns
- object - An iterator object that should first be `.initialize()`-ed. It may then either be passed to `strategy.experimental_run()`, or you can call `iterator.get_next()` to get the next value to pass to `strategy.extended.call_for_each_replica()`.
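The second pattern from the return description, sketched under the same assumptions (graph mode, `MirroredStrategy`; `replica_fn` is illustrative):

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 8

def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

def replica_fn(inputs):
  return tf.reduce_mean(inputs)

iterator = strategy.make_input_fn_iterator(input_fn)
# Fetch a value explicitly and hand it to call_for_each_replica,
# instead of passing the iterator to experimental_run.
per_replica = strategy.extended.call_for_each_replica(
    replica_fn, args=(iterator.get_next(),))

with tf.compat.v1.Session() as sess:
  sess.run(iterator.initialize())  # must be initialized before the first get_next()
  sess.run(strategy.experimental_local_results(per_replica))
```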
object make_input_fn_iterator_dyn(object input_fn, ImplicitContainer<T> replication_mode)
Returns an iterator split across replicas created from an input function.

DEPRECATED: This method is not available in TF 2.x.

The `input_fn` should take a `tf.distribute.InputContext` object, through which information about batching and input sharding can be accessed:
```
def input_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

with strategy.scope():
  iterator = strategy.make_input_fn_iterator(input_fn)
  replica_results = strategy.experimental_run(replica_fn, iterator)
```
The `tf.data.Dataset` returned by `input_fn` should have a per-replica batch size, which may be computed using `input_context.get_per_replica_batch_size`.
Parameters
- object input_fn - A function taking a `tf.distribute.InputContext` object and returning a `tf.data.Dataset`.
- ImplicitContainer&lt;T&gt; replication_mode - An enum value of `tf.distribute.InputReplicationMode`. Only `PER_WORKER` is supported currently, which means there will be a single call to `input_fn` per worker. Replicas will dequeue from the local `tf.data.Dataset` on their worker.
Returns
- object - An iterator object that should first be `.initialize()`-ed. It may then either be passed to `strategy.experimental_run()`, or you can call `iterator.get_next()` to get the next value to pass to `strategy.extended.call_for_each_replica()`.