Type StrategyExtended
Namespace tensorflow.distribute
Parent StrategyExtended
Interfaces IStrategyExtended
Additional APIs for algorithms that need to be distribution-aware.

Note: For most usage of tf.distribute.Strategy, there should be no need to call these methods, since TensorFlow libraries (such as optimizers) already call them on your behalf when needed.

Lower-level concepts:

* Wrapped values: In order to represent values parallel across devices (either replicas or the devices associated with a particular value), we wrap them in a "PerReplica" or "Mirrored" object that contains a map from replica id to values. "PerReplica" is used when the value may be different across replicas, and "Mirrored" when the values are the same (see the sketch after this list).
* Unwrapping and merging: Consider calling a function `fn` on multiple replicas, like `experimental_run_v2(fn, args=[w])` with an argument `w` that is a wrapped value. This means `w` will have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. `experimental_run_v2()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return values from `fn()`, which can possibly result in wrapped values. For example, let's say `fn()` returns a tuple with three components: `(x, a, v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component is the same object `x` from every replica, then the first component of the merged result will also be `x`. If the second component is different (`a`, `b`, ...) from each replica, then the merged value will have a wrapped map from replica device to the different values. If the third component is the members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.), then the merged result will be that mirrored variable (`v`).
* Worker devices vs. parameter devices: Most replica computations will happen on worker devices. Since we don't yet support model parallelism, there will be one worker device per replica. When using parameter servers or central storage, the set of devices holding variables may be different; otherwise the parameter devices might match the worker devices.
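For illustration, here is a minimal Python sketch of a per-replica ("PerReplica") value being produced in replica context and then aggregated. It assumes the TF 2.x Python API (as used in the examples elsewhere on this page), with a MirroredStrategy over whatever local devices are available; the function and printed values are illustrative, not part of this API.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per local device

def fn():
    ctx = tf.distribute.get_replica_context()
    # Different on every replica, so the merged result is a "PerReplica" value.
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

per_replica = strategy.experimental_run_v2(fn)
# Unwrap the per-replica container into a tuple of local tensors.
print(strategy.experimental_local_results(per_replica))
# Aggregate across replicas to get a single value usable in cross-replica context.
print(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None))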
*Replica context vs. Cross-replica context*

_Replica context_ is when we are in some function that is being called once for each replica. Otherwise we are in cross-replica context, which is useful for calling tf.distribute.Strategy methods which operate across the replicas (like `reduce_to()`). By default you start in a replica context (the "default single replica context") and then some methods can switch you back and forth. There is a third mode you can be in called _update context_ used when updating variables.

* tf.distribute.Strategy.scope: enters cross-replica context when no other strategy is in scope.
* tf.distribute.Strategy.experimental_run_v2: calls a function in replica context.
* tf.distribute.ReplicaContext.merge_call: transitions from replica context to cross-replica context.
* tf.distribute.StrategyExtended.update: calls a function in an update context from a cross-replica context.

In a replica context, you may freely read the values of variables, but you may only update their value if they specify a way to aggregate the update using the `aggregation` parameter in the variable's constructor. In a cross-replica context, you may read or write variables (writes may need to be broadcast to all copies of the variable if it is mirrored).
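A hedged sketch of this rule, assuming the TF 2.x Python API (the variable and values are made up): a variable can only be assigned from replica context if it declares how the per-replica updates should be combined.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():                                # cross-replica context
    v = tf.Variable(
        0.0,
        aggregation=tf.VariableAggregation.MEAN)      # required for replica-side assigns

def step():
    # Replica context: the assign is allowed because `v` declares an aggregation.
    v.assign_add(1.0)

strategy.experimental_run_v2(step)
print(v.numpy())  # every replica added 1.0; MEAN aggregation keeps the value at 1.0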
*Sync on read variables*

In some cases, such as a metric, we want to accumulate a bunch of updates on each replica independently and only aggregate when reading. This can be a big performance win when the value is read only rarely (maybe the value is only read at the end of an epoch or when checkpointing). These are variables created by passing `synchronization=ON_READ` to the variable's constructor (and some value for `aggregation`).

The strategy may choose to put the variable on multiple devices, like mirrored variables, but unlike mirrored variables we don't synchronize the updates to them to make sure they have the same value. Instead, the synchronization is performed when reading in cross-replica context. In a replica context, reads and writes are performed on the local copy (we allow reads so you can write code like `v = 0.9*v + 0.1*update`). We don't allow operations like `v.assign_add` in a cross-replica context for sync on read variables; right now we don't have a use case for such updates and depending on the aggregation mode such updates may not be sensible.
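A hedged sketch of such a metric-style variable, assuming the TF 2.x Python API (names are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Accumulated independently on each replica; summed only when read
    # in cross-replica context.
    total = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

def step(x):
    total.assign_add(x)          # updates only the local (per-replica) copy

strategy.experimental_run_v2(step, args=(tf.constant(2.0),))
# Reading in cross-replica context aggregates (sums) the per-replica copies.
print(total.read_value())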
*Locality*

Depending on how a value is produced, it will have a type that will determine how it may be used.

"Per-replica" values exist on the worker devices, with a different value for each replica. They are produced by iterating through a "distributed `Dataset`" returned by tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.experimental_distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.experimental_run_v2. You typically can't use a per-replica value directly in a cross-replica context without first resolving how to aggregate the values across replicas, for instance by using tf.distribute.Strategy.reduce.

"Mirrored" values are like per-replica values, except we know that the value on all replicas is the same. We can safely read a mirrored value in a cross-replica context by using the value on any replica. You can convert a per-replica value into a mirrored value by using tf.distribute.ReplicaContext.all_reduce.
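A hedged sketch of converting a per-replica value into a mirrored one with `all_reduce`, assuming the TF 2.x Python API (the computed value is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step():
    ctx = tf.distribute.get_replica_context()
    # A per-replica value: different on each replica.
    local = tf.cast(ctx.replica_id_in_sync_group + 1, tf.float32)
    # all_reduce returns the same (summed) tensor on every replica,
    # so the result can be treated as a mirrored value.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, local)

mirrored = strategy.experimental_run_v2(step)
print(strategy.experimental_local_results(mirrored))  # identical entries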
Values can also have the same locality as a variable, which is a mirrored value but residing on the same devices as the variable (as opposed to the compute devices). Such values may be passed to a call to tf.distribute.StrategyExtended.update to update the value of a variable. You may use tf.distribute.StrategyExtended.colocate_vars_with to give a variable the same locality as another variable. This is useful, for example, for "slot" variables used by an optimizer for keeping track of statistics used to update a primary/model variable. You may convert a per-replica value to a variable's locality by using tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to.

In addition to slot variables, which should be colocated with their primary variables, optimizers also define non-slot variables. These can be things like "number of step updates performed" or "beta1^t" and "beta2^t". Each strategy has some policy for which devices those variables should be copied to, called the "non-slot devices" (some subset of the parameter devices). We require that all non-slot variables are allocated on the same device, or mirrored across the same set of devices. You can use tf.distribute.StrategyExtended.non_slot_devices to pick a consistent set of devices to pass to both tf.distribute.StrategyExtended.colocate_vars_with and tf.distribute.StrategyExtended.update_non_slot.
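A hedged sketch of giving a slot variable the same locality as its primary variable via `colocate_vars_with`, assuming the TF 2.x Python API (`momentum` is an illustrative slot name, not part of this API):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    weight = tf.Variable(tf.zeros([10]))        # primary/model variable
    # The slot variable is placed on the same devices as `weight`.
    with strategy.extended.colocate_vars_with(weight):
        momentum = tf.Variable(tf.zeros([10]))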
*How to update a variable*

The standard pattern for updating variables (sketched in the example below) is to:

1. In your function passed to tf.distribute.Strategy.experimental_run_v2, compute a list of (update, variable) pairs. For example, the update might be the gradient of the loss with respect to the variable.
2. Switch to cross-replica mode by calling `tf.distribute.get_replica_context().merge_call()` with the updates and variables as arguments.
3. Call `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` (for one variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables) to sum the updates and broadcast the result to the variable's devices.
4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update its value.

Steps 2 through 4 are done automatically by class tf.keras.optimizers.Optimizer if you call its tf.keras.optimizers.Optimizer.apply_gradients method in a replica context. They are also done automatically if you call an `assign*` method on a (non sync-on-read) variable that was constructed with an aggregation method (which is used to determine the reduction used in step 3).
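A hedged sketch of steps 1-4 above without an optimizer, assuming the TF 2.x Python API (the loss, learning rate, and variable are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)

def apply_updates(dist, grads_and_vars):               # cross-replica context
    # Step 3: sum the per-replica gradients onto each variable's devices.
    reduced = dist.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, grads_and_vars)
    # Step 4: apply each reduced gradient in an update context.
    for grad, (_, var) in zip(reduced, grads_and_vars):
        dist.extended.update(
            var, lambda var_, g: var_.assign_sub(0.1 * g), args=(grad,))

def step():
    # Step 1: compute (update, variable) pairs in replica context.
    with tf.GradientTape() as tape:
        loss = v * v
    grad = tape.gradient(loss, v)
    # Step 2: switch to cross-replica context.
    tf.distribute.get_replica_context().merge_call(
        apply_updates, args=([(grad, v)],))

strategy.experimental_run_v2(step)
print(v.numpy())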
*Distribute-aware layers*

Layers are generally called in a replica context, except when defining a functional model. tf.distribute.in_cross_replica_context will let you determine which case you are in. If in a replica context, the tf.distribute.get_replica_context function will return a tf.distribute.ReplicaContext object. The `ReplicaContext` object has an `all_reduce` method for aggregating across all replicas. Alternatively, you can update variables following steps 2-4 above.
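A hedged sketch of a layer that behaves differently in the two contexts, assuming the TF 2.x Python API (the layer itself is made up for illustration):

import tensorflow as tf

class CrossReplicaCenter(tf.keras.layers.Layer):  # hypothetical layer
    """Subtracts the mean of the inputs, averaged over all replicas."""

    def call(self, x):
        local_mean = tf.reduce_mean(x, axis=0)
        if tf.distribute.in_cross_replica_context():
            # e.g. while tracing a functional model: no replicas to talk to.
            mean = local_mean
        else:
            ctx = tf.distribute.get_replica_context()
            # Average the per-replica means across all replicas.
            mean = ctx.all_reduce(tf.distribute.ReduceOp.MEAN, local_mean)
        return x - mean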
Note: For new tf.distribute.Strategy implementations, please put all logic in a subclass of tf.distribute.StrategyExtended. The only code needed for the tf.distribute.Strategy subclass is for instantiating your subclass of tf.distribute.StrategyExtended in the `__init__` method.
Methods
- batch_reduce_to
- batch_reduce_to_dyn
- broadcast_to
- broadcast_to_dyn
- call_for_each_replica
- call_for_each_replica
- call_for_each_replica
- call_for_each_replica_dyn
- experimental_make_numpy_dataset_dyn
- experimental_run_steps_on_iterator
- experimental_run_steps_on_iterator_dyn
- non_slot_devices
- non_slot_devices_dyn
- read_var
- read_var_dyn
- reduce_to
- reduce_to
- reduce_to_dyn
- update
- update_dyn
- update_non_slot
- update_non_slot_dyn
- value_container
- value_container_dyn
- variable_created_in_scope
- variable_created_in_scope_dyn
Properties
- experimental_between_graph
- experimental_between_graph_dyn
- experimental_require_static_shapes
- experimental_require_static_shapes_dyn
- experimental_should_init
- experimental_should_init_dyn
- parameter_devices
- parameter_devices_dyn
- PythonObject
- should_checkpoint
- should_checkpoint_dyn
- should_save_summary
- should_save_summary_dyn
- worker_devices
- worker_devices_dyn
Public instance methods
IList<object> batch_reduce_to(object reduce_op, object value_destination_pairs)
object batch_reduce_to_dyn(object reduce_op, object value_destination_pairs)
object broadcast_to(object tensor, object destinations)
Mirror a tensor on one device to all worker devices.
Parameters
-
object
tensor - A Tensor value to broadcast.
-
object
destinations - A mirrored variable or device string specifying the destination devices to copy `tensor` to.
Returns
-
object
- A value mirrored to `destinations` devices.
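A hedged sketch of how the corresponding Python call might look, per the parameters documented above (the strategy, tensor, and mirrored variable are illustrative, and the availability of `broadcast_to` on `strategy.extended` in a given TF version is an assumption):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(tf.zeros([3]))       # a mirrored variable (the destinations)

t = tf.constant([1.0, 2.0, 3.0])         # lives on a single (host) device
# Copy `t` to every device holding `v`; the result is a mirrored value.
mirrored_t = strategy.extended.broadcast_to(t, destinations=v)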
object broadcast_to_dyn(object tensor, object destinations)
Mirror a tensor on one device to all worker devices.
Parameters
-
object
tensor - A Tensor value to broadcast.
-
object
destinations - A mirrored variable or device string specifying the destination devices to copy `tensor` to.
Returns
-
object
- A value mirrored to `destinations` devices.
object call_for_each_replica(Template fn, ImplicitContainer<T> args, IDictionary<string, object> kwargs)
Run `fn` once per replica. `fn` may call `tf.get_replica_context()` to access methods such as
`replica_id_in_sync_group` and `merge_call()`. `merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
having encountered a `merge_call()` call. After that the
`merge_fn`-function is executed. Its results are then unwrapped and
given back to each replica call. After that execution resumes until
`fn` is complete or encounters another `merge_call()`. Example:
Parameters
-
Template
fn - function to run (will be run once per replica).
-
ImplicitContainer<T>
args - Tuple or list with positional arguments for `fn`.
-
IDictionary<string, object>
kwargs - Dict with keyword arguments for `fn`.
Returns
-
object
- Merged return value of `fn` across all replicas.
Show Example
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.experimental_run_v2(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))
object call_for_each_replica(TFDecorator fn, ImplicitContainer<T> args, IDictionary<string, object> kwargs)
Run `fn` once per replica. `fn` may call `tf.get_replica_context()` to access methods such as
`replica_id_in_sync_group` and `merge_call()`. `merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
having encountered a `merge_call()` call. After that the
`merge_fn`-function is executed. Its results are then unwrapped and
given back to each replica call. After that execution resumes until
`fn` is complete or encounters another `merge_call()`. Example:
Parameters
-
TFDecorator
fn - function to run (will be run once per replica).
-
ImplicitContainer<T>
args - Tuple or list with positional arguments for `fn`.
-
IDictionary<string, object>
kwargs - Dict with keyword arguments for `fn`.
Returns
-
object
- Merged return value of `fn` across all replicas.
Show Example
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.experimental_run_v2(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))
object call_for_each_replica(object fn, ImplicitContainer<T> args, IDictionary<string, object> kwargs)
Run `fn` once per replica. `fn` may call `tf.get_replica_context()` to access methods such as
`replica_id_in_sync_group` and `merge_call()`. `merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
having encountered a `merge_call()` call. After that the
`merge_fn`-function is executed. Its results are then unwrapped and
given back to each replica call. After that execution resumes until
`fn` is complete or encounters another `merge_call()`. Example:
Parameters
-
object
fn - function to run (will be run once per replica).
-
ImplicitContainer<T>
args - Tuple or list with positional arguments for `fn`.
-
IDictionary<string, object>
kwargs - Dict with keyword arguments for `fn`.
Returns
-
object
- Merged return value of `fn` across all replicas.
Show Example
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.experimental_run_v2(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))
object call_for_each_replica_dyn(object fn, ImplicitContainer<T> args, object kwargs)
Run `fn` once per replica. `fn` may call `tf.get_replica_context()` to access methods such as
`replica_id_in_sync_group` and `merge_call()`. `merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
having encountered a `merge_call()` call. After that the
`merge_fn`-function is executed. Its results are then unwrapped and
given back to each replica call. After that execution resumes until
`fn` is complete or encounters another `merge_call()`. Example:
Parameters
-
object
fn - function to run (will be run once per replica).
-
ImplicitContainer<T>
args - Tuple or list with positional arguments for `fn`.
-
object
kwargs - Dict with keyword arguments for `fn`.
Returns
-
object
- Merged return value of `fn` across all replicas.
Show Example
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.experimental_local_results(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.experimental_run_v2(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  # This statement prints a list:
  print(distribution.experimental_local_results(merged_results))
object experimental_make_numpy_dataset_dyn(object numpy_input, object session)
Makes a dataset for input provided via a numpy array. This avoids adding `numpy_input` as a large constant in the graph,
and copies the data to the machine or machines that will be processing
the input.
Parameters
-
object
numpy_input - A nest of NumPy input arrays that will be distributed evenly
across all replicas. Note that lists of Numpy arrays are stacked, as
that is normal
tf.data.Dataset
behavior. -
object
session - (TensorFlow v1.x graph execution only) A session used for initialization.
Returns
-
object
- A
tf.data.Dataset
representing `numpy_input`.
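A hedged Python sketch, per the parameters documented above (whether the `session` argument may be passed as `None` under eager execution is an assumption; per the description it is only needed for TF 1.x graph execution):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
features = np.arange(64, dtype=np.float32).reshape(16, 4)

# The NumPy data is copied to the processing machines instead of being
# embedded in the graph as a constant.
dataset = strategy.extended.experimental_make_numpy_dataset(features, session=None)
dist_dataset = strategy.experimental_distribute_dataset(dataset.batch(4))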
MultiStepContext experimental_run_steps_on_iterator(object fn, object iterator, int iterations, object initial_loop_values)
DEPRECATED: please use `experimental_run_v2` instead. Run `fn` with input from `iterator` for `iterations` times. This method can be used to run a step function for training a number of
times using input from a dataset.
Parameters
-
object
fn - function to run using this distribution strategy. The function must have the following signature: `def fn(context, inputs)`. `context` is an instance of `MultiStepContext` that will be passed when `fn` is run. `context` can be used to specify the outputs to be returned from `fn` by calling `context.set_last_step_output`. It can also be used to capture non-tensor outputs by `context.set_non_tensor_output`. See `MultiStepContext` documentation for more information. `inputs` will have the same type/structure as `iterator.get_next()`. Typically, `fn` will use the `call_for_each_replica` method of the strategy to distribute the computation over multiple replicas.
-
object
iterator - Iterator of a dataset that represents the input for `fn`. The caller is responsible for initializing the iterator as needed.
-
int
iterations - (Optional) Number of iterations that `fn` should be run. Defaults to 1.
-
object
initial_loop_values - (Optional) Initial values to be passed into the loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove initial_loop_values argument when we have a mechanism to infer the outputs of `fn`.
Returns
-
MultiStepContext
- Returns the `MultiStepContext` object which has the following properties, among other things:
  - run_op: An op that runs `fn` `iterations` times.
  - last_step_outputs: A dictionary containing tensors set using `context.set_last_step_output`. Evaluating this returns the value of the tensors after the last iteration.
  - non_tensor_outputs: A dictionary containing anything that was set by `fn` by calling `context.set_non_tensor_output`.
object experimental_run_steps_on_iterator_dyn(object fn, object iterator, ImplicitContainer<T> iterations, object initial_loop_values)
DEPRECATED: please use `experimental_run_v2` instead. Run `fn` with input from `iterator` for `iterations` times. This method can be used to run a step function for training a number of
times using input from a dataset.
Parameters
-
object
fn - function to run using this distribution strategy. The function must have the following signature: `def fn(context, inputs)`. `context` is an instance of `MultiStepContext` that will be passed when `fn` is run. `context` can be used to specify the outputs to be returned from `fn` by calling `context.set_last_step_output`. It can also be used to capture non-tensor outputs by `context.set_non_tensor_output`. See `MultiStepContext` documentation for more information. `inputs` will have the same type/structure as `iterator.get_next()`. Typically, `fn` will use the `call_for_each_replica` method of the strategy to distribute the computation over multiple replicas.
-
object
iterator - Iterator of a dataset that represents the input for `fn`. The caller is responsible for initializing the iterator as needed.
-
ImplicitContainer<T>
iterations - (Optional) Number of iterations that `fn` should be run. Defaults to 1.
-
object
initial_loop_values - (Optional) Initial values to be passed into the loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove initial_loop_values argument when we have a mechanism to infer the outputs of `fn`.
Returns
-
object
- Returns the `MultiStepContext` object which has the following properties, among other things:
  - run_op: An op that runs `fn` `iterations` times.
  - last_step_outputs: A dictionary containing tensors set using `context.set_last_step_output`. Evaluating this returns the value of the tensors after the last iteration.
  - non_tensor_outputs: A dictionary containing anything that was set by `fn` by calling `context.set_non_tensor_output`.
object non_slot_devices(object var_list)
object non_slot_devices_dyn(object var_list)
Tensor read_var(object v)
Reads the value of a variable. Returns the aggregate value of a replica-local variable, or the
(read-only) value of any other variable.
Parameters
-
object
v - A variable allocated within the scope of this
tf.distribute.Strategy
.
Returns
-
Tensor
- A tensor representing the value of `v`, aggregated across replicas if necessary.
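A hedged sketch of reading a replica-local (sync-on-read) variable through the corresponding Python API (TF 2.x assumed; the counter is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # A replica-local counter, e.g. backing a metric.
    count = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

strategy.experimental_run_v2(lambda: count.assign_add(1.0))
# Returns the aggregate (summed) value of the per-replica copies.
print(strategy.extended.read_var(count))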
object read_var_dyn(object v)
Reads the value of a variable. Returns the aggregate value of a replica-local variable, or the
(read-only) value of any other variable.
Parameters
-
object
v - A variable allocated within the scope of this
tf.distribute.Strategy
.
Returns
-
object
- A tensor representing the value of `v`, aggregated across replicas if necessary.
object reduce_to(ReduceOp reduce_op, object value, object destinations)
object reduce_to(object reduce_op, object value, object destinations)
object reduce_to_dyn(object reduce_op, object value, object destinations)
object update(object var, object fn, ValueTuple<object> args, IDictionary<string, object> kwargs, bool group)
object update_dyn(object var, object fn, ImplicitContainer<T> args, object kwargs, ImplicitContainer<T> group)
object update_non_slot(object colocate_with, object fn, ImplicitContainer<T> args, IDictionary<string, object> kwargs, bool group)
object update_non_slot_dyn(object colocate_with, object fn, ImplicitContainer<T> args, object kwargs, ImplicitContainer<T> group)
object value_container(object value)
object value_container_dyn(object value)
bool variable_created_in_scope(object v)
object variable_created_in_scope_dyn(object v)
Public properties
bool experimental_between_graph get;
Whether the strategy uses between-graph replication or not. This is expected to return a constant value that will not be changed
throughout its life cycle.
object experimental_between_graph_dyn get;
Whether the strategy uses between-graph replication or not. This is expected to return a constant value that will not be changed
throughout its life cycle.
bool experimental_require_static_shapes get;
object experimental_require_static_shapes_dyn get;
bool experimental_should_init get;
Whether initialization is needed.
object experimental_should_init_dyn get;
Whether initialization is needed.
object parameter_devices get;
object parameter_devices_dyn get;
object PythonObject get;
bool should_checkpoint get;
Whether checkpointing is needed.
object should_checkpoint_dyn get;
Whether checkpointing is needed.
bool should_save_summary get;
Whether saving summaries is needed.
object should_save_summary_dyn get;
Whether saving summaries is needed.