Type Checkpoint
Namespace tensorflow.train
Parent AutoTrackable
Interfaces ICheckpoint
Groups trackable objects, saving and restoring them.

`Checkpoint`'s constructor accepts keyword arguments whose values are types that contain trackable state, such as `tf.compat.v1.train.Optimizer` implementations, `tf.Variable`, `tf.keras.Layer` implementations, or `tf.keras.Model` implementations. It saves these values with a checkpoint, and maintains a `save_counter` for numbering checkpoints.

An example of usage when graph building follows this description. With eager execution enabled, no session is needed, since variables are initialized or restored eagerly.
`Checkpoint.save` and `Checkpoint.restore` write and read object-based checkpoints, in contrast to `tf.compat.v1.train.Saver`, which writes and reads `variable.name`-based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new code.

`Checkpoint` objects have dependencies on the objects passed as keyword arguments to their constructors, and each dependency is given a name that is identical to the name of the keyword argument for which it was created. TensorFlow classes like `Layer`s and `Optimizer`s will automatically add dependencies on their variables (e.g. "kernel" and "bias" for `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing dependencies easy in user-defined classes, since `Model` hooks into attribute assignment.
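A minimal sketch of how hooking attribute assignment can produce named dependencies, written in plain Python rather than TensorFlow. The `Trackable`, `Variable`, `Dense`, and `Regress` classes and the `variable_paths` helper are hypothetical stand-ins for illustration, not TensorFlow classes, and the slash-separated paths are a simplification:

```python
class Trackable:
    """Toy stand-in for a trackable object (not a TensorFlow class)."""

    def __init__(self):
        # Bypass our own __setattr__ hook while creating the registry.
        object.__setattr__(self, "_deps", {})

    def __setattr__(self, name, value):
        # Record a named edge whenever another Trackable is assigned.
        if isinstance(value, Trackable):
            self._deps[name] = value
        object.__setattr__(self, name, value)


class Variable(Trackable):
    def __init__(self, value):
        super().__init__()
        self.value = value


class Dense(Trackable):
    def __init__(self, units):
        super().__init__()
        self.kernel = Variable([0.0] * units)
        self.bias = Variable(0.0)


class Regress(Trackable):
    def __init__(self):
        super().__init__()
        self.input_transform = Dense(10)


def variable_paths(root, prefix=""):
    """Return {path: Variable} for every variable reachable from root."""
    found = {}
    for name, child in root._deps.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(child, Variable):
            found[path] = child
        found.update(variable_paths(child, path))
    return found


paths = sorted(variable_paths(Regress()))
print(paths)  # ['input_transform/bias', 'input_transform/kernel']
```

Because the edges are recorded at assignment time, the user-defined class never has to list its variables explicitly; walking the edges from the root finds them.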
For example, a `Model` subclass named `Regress` that assigns a `Dense` layer to an attribute called `input_transform` has a dependency named "input_transform" on that layer, which in turn depends on its variables. As a result, saving an instance of `Regress` using `tf.train.Checkpoint` will also save all the variables created by the `Dense` layer.

When variables are assigned to multiple workers, each worker writes its own section of the checkpoint. These sections are then merged/re-indexed to behave as a single checkpoint. This avoids copying all variables to one worker, but does require that all workers see a common filesystem.

While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the same format, note that the root of the resulting checkpoint is the object the save method is attached to. This means saving a `tf.keras.Model` using `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match the `Model`'s variables. See the [guide to training checkpoints](https://www.tensorflow.org/alpha/guide/checkpoints) for details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for training checkpoints.
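One way to picture why the root matters is to flatten an object graph into path-like keys, starting from whichever object is the root. The sketch below does this in plain Python with nested dicts. The `flatten` helper and the slash-separated key format are assumptions made for illustration and are simpler than TensorFlow's actual checkpoint keys:

```python
def flatten(node, prefix=""):
    """Flatten a nested dict of named edges into {path: value}."""
    flat = {}
    for name, child in node.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(child, dict):
            flat.update(flatten(child, path))
        else:
            flat[path] = child
    return flat


model = {"input_transform": {"kernel": [1.0, 2.0], "bias": 0.5}}

# Root is the model itself (analogous to calling save_weights on the model).
saved_from_model = flatten(model)

# Root is a wrapper holding the model under the edge name "model"
# (analogous to a Checkpoint constructed with model=model).
expected_by_wrapper = flatten({"model": model})

matched = set(saved_from_model) & set(expected_by_wrapper)
print(sorted(saved_from_model))     # ['input_transform/bias', 'input_transform/kernel']
print(sorted(expected_by_wrapper))  # ['model/input_transform/bias', 'model/input_transform/kernel']
print(matched)                      # set()
```

The values are identical in both cases; only the paths from the root differ, which is enough for nothing to match.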
Example usage when graph building:

    import tensorflow as tf
    import os

    checkpoint_directory = "/tmp/training_checkpoints"
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
    train_op = optimizer.minimize(... )
    status.assert_consumed()  # Optional sanity checks.
    with tf.compat.v1.Session() as session:
        # Use the Session to restore variables, or initialize them if
        # tf.train.latest_checkpoint returned None.
        status.initialize_or_restore(session)
        for _ in range(num_training_steps):
            session.run(train_op)
        checkpoint.save(file_prefix=checkpoint_prefix)
Public instance methods
object restore(Byte[] save_path)
Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on.

When executing eagerly, either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any trackable object waiting for the expected dependency to be added). When graph building, restoration ops are added to the graph but not run immediately.

To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by `restore`, as in `checkpoint.restore(path).assert_consumed()`. An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.

When graph building, `assert_consumed()` indicates that all of the restore ops that will be created for this checkpoint have been created. They can be run via the `run_restore_ops()` method of the status object, as in `checkpoint.restore(path).assert_consumed().run_restore_ops()`. If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.

Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this method. Names are used to match variables. No restore ops are created/run until `run_restore_ops()` or `initialize_or_restore()` is called on the returned status object when graph building, but there is restore-on-creation when executing eagerly. Re-encode name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
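The eager behavior described above (assign immediately if the variable already exists, otherwise defer until it is created) can be sketched with a toy class. This is plain Python, not TensorFlow. `ToyObjectGraph` and its methods are hypothetical names, and the sketch only tracks checkpoint values that have not yet found a variable:

```python
class ToyObjectGraph:
    """Toy model of deferred restoration (not the TensorFlow implementation)."""

    def __init__(self):
        self.variables = {}  # name -> current value
        self._pending = {}   # name -> saved value waiting for its variable

    def add_variable(self, name, initial_value):
        # Restore-on-create: use a queued checkpoint value if one is waiting.
        self.variables[name] = self._pending.pop(name, initial_value)

    def restore(self, saved):
        for name, value in saved.items():
            if name in self.variables:
                self.variables[name] = value  # assign immediately
            else:
                self._pending[name] = value   # defer until the variable exists

    def fully_consumed(self):
        return not self._pending


graph = ToyObjectGraph()
graph.add_variable("kernel", 0.0)
graph.restore({"kernel": 3.0, "bias": 1.5})
before = (dict(graph.variables), graph.fully_consumed())

graph.add_variable("bias", 0.0)  # created after restore() was called
after = (dict(graph.variables), graph.fully_consumed())

print(before)  # ({'kernel': 3.0}, False)
print(after)   # ({'kernel': 3.0, 'bias': 1.5}, True)
```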
Parameters
- Byte[] save_path - The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest checkpoint for `tf.train.latest_checkpoint` to return), returns an object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables.
Returns
- object - A load status object, which can be used to make assertions about the status of a checkpoint restoration and run initialization/restore ops. The returned status object has the following methods:
  * `assert_consumed()`: Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_existing_objects_matched()`: Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike `assert_consumed`, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example a `tf.keras.Layer` object which has not yet been built, and so has not created any variables, will pass this assertion but fail `assert_consumed`. Useful when loading part of a larger checkpoint into a new Python program, e.g. a training checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_nontrivial_match()`: Asserts that something aside from the root object was matched. This is a very weak assertion, but is useful for sanity checking in library code where objects may exist in the checkpoint which haven't been created in Python and some Python objects may not have a checkpointed value.
  * `expect_partial()`: Silence warnings about incomplete checkpoint restores. Warnings are otherwise printed for unused parts of the checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown).
  * `initialize_or_restore(session=None)`: When graph building, runs variable initializers if `save_path` is `None`, but otherwise runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly).
  * `run_restore_ops(session=None)`: When graph building, runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be called when `save_path` is not `None`.
Example:

    checkpoint = tf.train.Checkpoint(... )
    checkpoint.restore(path).assert_consumed()
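The difference between `assert_consumed()` and `assert_existing_objects_matched()` can be illustrated by treating the Python objects and the checkpoint contents as two sets of names. The functions below are a plain-Python sketch that only borrows the method names. They are not the TensorFlow implementations, and the example names such as `optimizer/momentum` are made up:

```python
def assert_consumed(python_names, checkpoint_names):
    """Fail if anything is unmatched on either side."""
    unmatched = set(python_names) ^ set(checkpoint_names)
    if unmatched:
        raise AssertionError(f"unmatched: {sorted(unmatched)}")


def assert_existing_objects_matched(python_names, checkpoint_names):
    """Fail only if an existing Python object has no checkpointed value."""
    missing = set(python_names) - set(checkpoint_names)
    if missing:
        raise AssertionError(f"no checkpoint value for: {sorted(missing)}")


# A training checkpoint that also stored optimizer state...
checkpoint_names = {"model/kernel", "model/bias", "optimizer/momentum"}
# ...loaded by an inference program that only builds the model.
inference_names = {"model/kernel", "model/bias"}

# Passes: every existing Python object found a value.
assert_existing_objects_matched(inference_names, checkpoint_names)

try:
    assert_consumed(inference_names, checkpoint_names)
    consumed = True
except AssertionError as error:
    consumed = False
    message = str(error)
    print(message)  # unmatched: ['optimizer/momentum']
```

In this partial-load situation the weaker assertion is the appropriate check, which matches the inference use case described for `assert_existing_objects_matched()`.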
object restore(IGraphNodeBase save_path)
Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on.

When executing eagerly, either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any trackable object waiting for the expected dependency to be added). When graph building, restoration ops are added to the graph but not run immediately.

To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by `restore`, as in `checkpoint.restore(path).assert_consumed()`. An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.

When graph building, `assert_consumed()` indicates that all of the restore ops that will be created for this checkpoint have been created. They can be run via the `run_restore_ops()` method of the status object, as in `checkpoint.restore(path).assert_consumed().run_restore_ops()`. If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.

Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this method. Names are used to match variables. No restore ops are created/run until `run_restore_ops()` or `initialize_or_restore()` is called on the returned status object when graph building, but there is restore-on-creation when executing eagerly. Re-encode name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
Parameters
- IGraphNodeBase save_path - The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest checkpoint for `tf.train.latest_checkpoint` to return), returns an object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables.
Returns
- object - A load status object, which can be used to make assertions about the status of a checkpoint restoration and run initialization/restore ops. The returned status object has the following methods:
  * `assert_consumed()`: Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_existing_objects_matched()`: Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike `assert_consumed`, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example, a `tf.keras.Layer` object which has not yet been built, and so has not created any variables, will pass this assertion but fail `assert_consumed`. Useful when loading part of a larger checkpoint into a new Python program, e.g. a training checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_nontrivial_match()`: Asserts that something aside from the root object was matched. This is a very weak assertion, but is useful for sanity checking in library code where objects may exist in the checkpoint which haven't been created in Python and some Python objects may not have a checkpointed value.
  * `expect_partial()`: Silence warnings about incomplete checkpoint restores. Warnings are otherwise printed for unused parts of the checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown).
  * `initialize_or_restore(session=None)`: When graph building, runs variable initializers if `save_path` is `None`, but otherwise runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly).
  * `run_restore_ops(session=None)`: When graph building, runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be called when `save_path` is not `None`.
Example:

    checkpoint = tf.train.Checkpoint(... )
    checkpoint.restore(path).assert_consumed()
object restore(string save_path)
Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on.

When executing eagerly, either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any trackable object waiting for the expected dependency to be added). When graph building, restoration ops are added to the graph but not run immediately.

To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by `restore`, as in `checkpoint.restore(path).assert_consumed()`. An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.

When graph building, `assert_consumed()` indicates that all of the restore ops that will be created for this checkpoint have been created. They can be run via the `run_restore_ops()` method of the status object, as in `checkpoint.restore(path).assert_consumed().run_restore_ops()`. If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.

Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this method. Names are used to match variables. No restore ops are created/run until `run_restore_ops()` or `initialize_or_restore()` is called on the returned status object when graph building, but there is restore-on-creation when executing eagerly. Re-encode name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
Parameters
- string save_path - The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest checkpoint for `tf.train.latest_checkpoint` to return), returns an object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables.
Returns
- object - A load status object, which can be used to make assertions about the status of a checkpoint restoration and run initialization/restore ops. The returned status object has the following methods:
  * `assert_consumed()`: Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_existing_objects_matched()`: Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike `assert_consumed`, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example, a `tf.keras.Layer` object which has not yet been built, and so has not created any variables, will pass this assertion but fail `assert_consumed`. Useful when loading part of a larger checkpoint into a new Python program, e.g. a training checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_nontrivial_match()`: Asserts that something aside from the root object was matched. This is a very weak assertion, but is useful for sanity checking in library code where objects may exist in the checkpoint which haven't been created in Python and some Python objects may not have a checkpointed value.
  * `expect_partial()`: Silence warnings about incomplete checkpoint restores. Warnings are otherwise printed for unused parts of the checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown).
  * `initialize_or_restore(session=None)`: When graph building, runs variable initializers if `save_path` is `None`, but otherwise runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly).
  * `run_restore_ops(session=None)`: When graph building, runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be called when `save_path` is not `None`.
Example:

    checkpoint = tf.train.Checkpoint(... )
    checkpoint.restore(path).assert_consumed()
object restore_dyn(object save_path)
Restore a training checkpoint. Restores this `Checkpoint` and any objects it depends on.

When executing eagerly, either assigns values immediately if variables to restore have been created already, or defers restoration until the variables are created. Dependencies added after this call will be matched if they have a corresponding object in the checkpoint (the restore request will queue in any trackable object waiting for the expected dependency to be added). When graph building, restoration ops are added to the graph but not run immediately.

To ensure that loading is complete and no more assignments will take place, use the `assert_consumed()` method of the status object returned by `restore`, as in `checkpoint.restore(path).assert_consumed()`. An exception will be raised if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.

When graph building, `assert_consumed()` indicates that all of the restore ops that will be created for this checkpoint have been created. They can be run via the `run_restore_ops()` method of the status object, as in `checkpoint.restore(path).assert_consumed().run_restore_ops()`. If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.

Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this method. Names are used to match variables. No restore ops are created/run until `run_restore_ops()` or `initialize_or_restore()` is called on the returned status object when graph building, but there is restore-on-creation when executing eagerly. Re-encode name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
Parameters
- object save_path - The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest checkpoint for `tf.train.latest_checkpoint` to return), returns an object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.compat.v1.train.Saver`, names are used to match variables.
Returns
- object - A load status object, which can be used to make assertions about the status of a checkpoint restoration and run initialization/restore ops. The returned status object has the following methods:
  * `assert_consumed()`: Raises an exception if any variables/objects are unmatched: either checkpointed values which don't have a matching Python object or Python objects in the dependency graph with no values in the checkpoint. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_existing_objects_matched()`: Raises an exception if any existing Python objects in the dependency graph are unmatched. Unlike `assert_consumed`, this assertion will pass if values in the checkpoint have no corresponding Python objects. For example, a `tf.keras.Layer` object which has not yet been built, and so has not created any variables, will pass this assertion but fail `assert_consumed`. Useful when loading part of a larger checkpoint into a new Python program, e.g. a training checkpoint with a `tf.compat.v1.train.Optimizer` was saved but only the state required for inference is being loaded. This method returns the status object, and so may be chained with `initialize_or_restore` or `run_restore_ops`.
  * `assert_nontrivial_match()`: Asserts that something aside from the root object was matched. This is a very weak assertion, but is useful for sanity checking in library code where objects may exist in the checkpoint which haven't been created in Python and some Python objects may not have a checkpointed value.
  * `expect_partial()`: Silence warnings about incomplete checkpoint restores. Warnings are otherwise printed for unused parts of the checkpoint file or object when the `Checkpoint` object is deleted (often at program shutdown).
  * `initialize_or_restore(session=None)`: When graph building, runs variable initializers if `save_path` is `None`, but otherwise runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (variables are initialized or restored eagerly).
  * `run_restore_ops(session=None)`: When graph building, runs restore operations. If no `session` is explicitly specified, the default session is used. No effect when executing eagerly (restore operations are run eagerly). May only be called when `save_path` is not `None`.
Example:

    checkpoint = tf.train.Checkpoint(... )
    checkpoint.restore(path).assert_consumed()
object save(Byte[] file_prefix, Nullable<bool> session)
Saves a training checkpoint and provides basic checkpoint management.

The saved checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.save()` is called.

`save` is a basic convenience wrapper around the `write` method, sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be provided by other utilities which also wrap `write` (`tf.contrib.checkpoint.CheckpointManager` for example).
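The relationship between `save` and `write` can be sketched as a thin numbering layer over a plain write. The toy below is plain Python, not the TensorFlow implementation. `ToyCheckpoint`, the `latest_checkpoint` helper, and the marker file named `latest` are all hypothetical, chosen only to mirror the description above:

```python
import os
import tempfile


class ToyCheckpoint:
    """Toy model of save() wrapping write() (not the TensorFlow code)."""

    def __init__(self, state):
        self.state = state
        self.save_counter = 0

    def write(self, file_prefix):
        # Writes to exactly the given path: no numbering, no metadata update.
        with open(file_prefix, "w") as f:
            f.write(repr(self.state))
        return file_prefix

    def save(self, file_prefix):
        # Number the checkpoint, write it, then record it as the latest one.
        self.save_counter += 1
        path = self.write(f"{file_prefix}-{self.save_counter}")
        marker = os.path.join(os.path.dirname(file_prefix), "latest")
        with open(marker, "w") as f:
            f.write(path)
        return path


def latest_checkpoint(directory):
    """Return the most recently saved path, or None if nothing was saved."""
    marker = os.path.join(directory, "latest")
    if not os.path.exists(marker):
        return None
    with open(marker) as f:
        return f.read()


with tempfile.TemporaryDirectory() as directory:
    prefix = os.path.join(directory, "ckpt")
    checkpoint = ToyCheckpoint({"step": 0})
    nothing_yet = latest_checkpoint(directory)
    first = checkpoint.save(prefix)
    second = checkpoint.save(prefix)
    write_kept_name = checkpoint.write(prefix) == prefix
    latest = latest_checkpoint(directory)
    names = tuple(os.path.basename(p) for p in (first, second, latest))

print(nothing_yet, names)  # None ('ckpt-1', 'ckpt-2', 'ckpt-2')
```

Note that the direct `write` call neither advances the counter nor changes what `latest_checkpoint` reports, which is why higher-level managers can wrap `write` with their own numbering.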
Parameters
- Byte[] file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and `Checkpoint.save_counter`.
- Nullable<bool> session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint.
object save(IGraphNodeBase file_prefix, Nullable<bool> session)
Saves a training checkpoint and provides basic checkpoint management.

The saved checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.save()` is called.

`save` is a basic convenience wrapper around the `write` method, sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be provided by other utilities which also wrap `write` (`tf.contrib.checkpoint.CheckpointManager` for example).
Parameters
- IGraphNodeBase file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and `Checkpoint.save_counter`.
- Nullable<bool> session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint.
object save(string file_prefix, Nullable<bool> session)
Saves a training checkpoint and provides basic checkpoint management.

The saved checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.save()` is called.

`save` is a basic convenience wrapper around the `write` method, sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be provided by other utilities which also wrap `write` (`tf.contrib.checkpoint.CheckpointManager` for example).
Parameters
- string file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and `Checkpoint.save_counter`.
- Nullable<bool> session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint.
object save_dyn(object file_prefix, object session)
Saves a training checkpoint and provides basic checkpoint management.

The saved checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.save()` is called.

`save` is a basic convenience wrapper around the `write` method, sequentially numbering checkpoints using `save_counter` and updating the metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint management, for example garbage collection and custom numbering, may be provided by other utilities which also wrap `write` (`tf.contrib.checkpoint.CheckpointManager` for example).
Parameters
- object file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this prefix and `Checkpoint.save_counter`.
- object session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint.
object write(string file_prefix, bool session)
Writes a training checkpoint.

The checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.write()` is called.

`write` does not number checkpoints, increment `save_counter`, or update the metadata used by `tf.train.latest_checkpoint`. It is primarily intended for use by higher level checkpoint management utilities. `save` provides a very basic implementation of these features.
Parameters
- string file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix).
- bool session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint (i.e. `file_prefix`).
object write(string file_prefix, BaseSession session)
Writes a training checkpoint.

The checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.write()` is called.

`write` does not number checkpoints, increment `save_counter`, or update the metadata used by `tf.train.latest_checkpoint`. It is primarily intended for use by higher level checkpoint management utilities. `save` provides a very basic implementation of these features.
Parameters
- string file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix).
- BaseSession session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint (i.e. `file_prefix`).
object write_dyn(object file_prefix, object session)
Writes a training checkpoint.

The checkpoint includes variables created by this object and any trackable objects it depends on at the time `Checkpoint.write()` is called.

`write` does not number checkpoints, increment `save_counter`, or update the metadata used by `tf.train.latest_checkpoint`. It is primarily intended for use by higher level checkpoint management utilities. `save` provides a very basic implementation of these features.
Parameters
- object file_prefix - A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix).
- object session - The session to evaluate variables in. Ignored when executing eagerly. If not provided when graph building, the default session is used.
Returns
- object - The full path to the checkpoint (i.e. `file_prefix`).
Public static methods
Checkpoint NewDyn(IDictionary<string, object> kwargs)
Group objects into a training checkpoint.
Parameters
- IDictionary<string, object> kwargs - Keyword arguments are set as attributes of this object, and are saved with the checkpoint. Values must be trackable objects.
Public properties
object PythonObject get;
NoDependency save_counter get;
An integer variable which starts at zero and is incremented on save. Used to number checkpoints.
object save_counter_dyn get;
An integer variable which starts at zero and is incremented on save. Used to number checkpoints.