Type tf.keras.estimator
Namespace tensorflow
Public static methods
object model_to_estimator(object keras_model, object keras_model_path, object custom_objects, object model_dir, object config, string checkpoint_format)
Constructs an `Estimator` instance from the given Keras model. For a usage example, see
[Creating estimators from Keras Models](https://tensorflow.org/guide/estimators#model_to_estimator).

__Sample Weights__
Estimators returned by `model_to_estimator` are configured to handle sample
weights (similar to `keras_model.fit(x, y, sample_weight=...)`). To pass sample
weights when training or evaluating the Estimator, the first item returned by
the input function should be a dictionary with keys `features` and
`sample_weights`. For example:

```
import tensorflow as tf

keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)

def input_fn():
  return tf.data.Dataset.from_tensors(
      ({'features': features, 'sample_weights': sample_weights}, targets))

estimator.train(input_fn, steps=1)
```
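To make the expected input layout concrete without requiring TensorFlow, here is a pure-Python sketch. `make_weighted_element` is a hypothetical helper, not part of TensorFlow: it assembles one element in the `({'features': ..., 'sample_weights': ...}, targets)` shape described above, with plain lists standing in for tensors. The length check (one weight per target) is an assumption added for illustration.

```python
# Hypothetical helper, not a TensorFlow API: builds one input element in the
# ({'features': ..., 'sample_weights': ...}, targets) layout described above.
def make_weighted_element(features, sample_weights, targets):
    # Assumption for this sketch: exactly one weight per target.
    if len(sample_weights) != len(targets):
        raise ValueError("need one sample weight per target")
    inputs = {"features": features, "sample_weights": sample_weights}
    return (inputs, targets)


inputs, labels = make_weighted_element(
    features=[[0.1, 0.2], [0.3, 0.4]],
    sample_weights=[1.0, 0.5],
    targets=[0, 1],
)
print(sorted(inputs))  # ['features', 'sample_weights']
```

In the real input function, the same nesting is passed to `tf.data.Dataset.from_tensors`, as in the example above.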
Parameters
-
object
keras_model - A compiled Keras model object. This argument is mutually exclusive with `keras_model_path`.
-
object
keras_model_path - Path to a compiled Keras model saved on disk, in HDF5 format, which can be generated with the `save()` method of a Keras model. This argument is mutually exclusive with `keras_model`.
-
object
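custom_objects - Dictionary mapping names to custom objects (for example, user-defined layer classes or functions) that are not part of TensorFlow but are needed to rebuild the model.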
-
object
model_dir - Directory to save `Estimator` model parameters, graph, summary files for TensorBoard, etc.
-
object
config - `RunConfig` used to configure the `Estimator`.
-
string
checkpoint_format - Sets the format of the checkpoint saved by the estimator
when training. May be `saver` or `checkpoint`, depending on whether to
save checkpoints from `tf.train.Saver` or `tf.train.Checkpoint`.
This argument currently defaults to `saver`. When 2.0 is released, the default will be `checkpoint`. Estimators use name-based `tf.train.Saver`
checkpoints, while Keras models use object-based checkpoints from `tf.train.Checkpoint`.
Currently, saving object-based checkpoints from `model_to_estimator` is only supported by Functional and Sequential models.
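The mutual exclusivity of `keras_model` and `keras_model_path` can be summarized as "exactly one of the two must be given". The pure-Python sketch below expresses that rule; `resolve_model_source` and its return labels are hypothetical names for illustration, not TensorFlow's implementation, and the sketch also assumes that passing neither argument is an error.

```python
# Hypothetical sketch, not TensorFlow code: enforce that exactly one of
# keras_model / keras_model_path is supplied.
def resolve_model_source(keras_model=None, keras_model_path=None):
    # True == True (both given) or False == False (neither given) -> error.
    if (keras_model is None) == (keras_model_path is None):
        raise ValueError(
            "Provide exactly one of keras_model or keras_model_path.")
    if keras_model is not None:
        return ("in_memory", keras_model)
    # Per the parameter docs, a path points to an HDF5 file written by save().
    return ("hdf5_path", keras_model_path)


print(resolve_model_source(keras_model_path="model.h5"))  # ('hdf5_path', 'model.h5')
```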
Returns
-
object
- An Estimator built from the given Keras model.
object model_to_estimator_dyn(object keras_model, object keras_model_path, object custom_objects, object model_dir, object config, ImplicitContainer<T> checkpoint_format)
Constructs an `Estimator` instance from the given Keras model. For a usage example, see
[Creating estimators from Keras Models](https://tensorflow.org/guide/estimators#model_to_estimator).

__Sample Weights__
Estimators returned by `model_to_estimator` are configured to handle sample
weights (similar to `keras_model.fit(x, y, sample_weight=...)`). To pass sample
weights when training or evaluating the Estimator, the first item returned by
the input function should be a dictionary with keys `features` and
`sample_weights`. For example:

```
import tensorflow as tf

keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)

def input_fn():
  return tf.data.Dataset.from_tensors(
      ({'features': features, 'sample_weights': sample_weights}, targets))

estimator.train(input_fn, steps=1)
```
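As a rough illustration of what weighting examples does to a loss, the pure-Python sketch below scales each per-example loss by its weight and divides the weighted sum by the batch size. That "sum over batch size" reduction is an assumption chosen for the sketch: the reduction the Estimator actually applies depends on how the Keras model's loss was configured, and `weighted_batch_loss` is a hypothetical name, not a TensorFlow function.

```python
# Hypothetical sketch, not TensorFlow code: weighted mean of per-example
# losses using an assumed "sum over batch size" reduction.
def weighted_batch_loss(per_example_losses, sample_weights):
    if len(per_example_losses) != len(sample_weights):
        raise ValueError("need one weight per example")
    # Each example's loss is scaled by its weight before summing.
    total = sum(loss * w for loss, w in zip(per_example_losses, sample_weights))
    # Divide by the number of examples, not by the sum of the weights.
    return total / len(per_example_losses)


print(weighted_batch_loss([2.0, 4.0], [1.0, 0.5]))  # 2.0
```

A weight of `0.0` removes an example's contribution entirely, while a weight above `1.0` makes it count more than an unweighted example.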
Parameters
-
object
keras_model - A compiled Keras model object. This argument is mutually exclusive with `keras_model_path`.
-
object
keras_model_path - Path to a compiled Keras model saved on disk, in HDF5 format, which can be generated with the `save()` method of a Keras model. This argument is mutually exclusive with `keras_model`.
-
object
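custom_objects - Dictionary mapping names to custom objects (for example, user-defined layer classes or functions) that are not part of TensorFlow but are needed to rebuild the model.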
-
object
model_dir - Directory to save `Estimator` model parameters, graph, summary files for TensorBoard, etc.
-
object
config - `RunConfig` used to configure the `Estimator`.
-
ImplicitContainer<T>
checkpoint_format - Sets the format of the checkpoint saved by the estimator
when training. May be `saver` or `checkpoint`, depending on whether to
save checkpoints from `tf.train.Saver` or `tf.train.Checkpoint`.
This argument currently defaults to `saver`. When 2.0 is released, the default will be `checkpoint`. Estimators use name-based `tf.train.Saver`
checkpoints, while Keras models use object-based checkpoints from `tf.train.Checkpoint`.
Currently, saving object-based checkpoints from `model_to_estimator` is only supported by Functional and Sequential models.
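The rules for `checkpoint_format` can be captured in a small lookup table. The pure-Python sketch below maps each documented value to the saving API and checkpoint style named in the text, uses `saver` as the default to mirror the current documented default, and rejects object-based checkpoints for model kinds other than Functional and Sequential. `describe_checkpoint_format` and its `model_kind` string labels are hypothetical, for illustration only; this is not how TensorFlow validates the argument.

```python
# Hypothetical lookup table, not TensorFlow code: the saving API and
# checkpoint style behind each documented checkpoint_format value.
CHECKPOINT_FORMATS = {
    "saver": ("tf.train.Saver", "name-based"),
    "checkpoint": ("tf.train.Checkpoint", "object-based"),
}

# Model kinds that, per the text, support object-based checkpoints here.
OBJECT_CHECKPOINT_MODELS = {"functional", "sequential"}


def describe_checkpoint_format(checkpoint_format="saver", model_kind="functional"):
    # model_kind is a hypothetical label such as "functional",
    # "sequential", or "subclassed", used only in this illustration.
    if checkpoint_format not in CHECKPOINT_FORMATS:
        raise ValueError(
            "checkpoint_format must be 'saver' or 'checkpoint', "
            f"got {checkpoint_format!r}")
    if checkpoint_format == "checkpoint" and model_kind not in OBJECT_CHECKPOINT_MODELS:
        raise ValueError(
            "object-based checkpoints are only supported for "
            "Functional and Sequential models")
    api, style = CHECKPOINT_FORMATS[checkpoint_format]
    return f"{style} checkpoints via {api}"


print(describe_checkpoint_format())  # name-based checkpoints via tf.train.Saver
print(describe_checkpoint_format("checkpoint", "sequential"))  # object-based checkpoints via tf.train.Checkpoint
```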
Returns
-
object
- An Estimator built from the given Keras model.