
tfestimators (version 1.9.2)

estimator_spec: Define an Estimator Specification

Description

Define the estimator specification returned by the model_fn of a custom estimator created with estimator(). See estimator() for more details.

Usage

estimator_spec(
  mode,
  predictions = NULL,
  loss = NULL,
  train_op = NULL,
  eval_metric_ops = NULL,
  training_hooks = NULL,
  evaluation_hooks = NULL,
  prediction_hooks = NULL,
  training_chief_hooks = NULL,
  ...
)
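A minimal sketch of a model_fn that returns a different specification for each mode is shown below. It assumes the TensorFlow 1.x API that tfestimators targets (tf$layers, tf$losses, tf$train, tf$metrics) and assumes that features arrive as a named list holding a single numeric feature x, with one numeric label per example. The names model_fn, x, y and y_hat are illustrative, not part of the package.

```r
library(tensorflow)
library(tfestimators)

# Sketch of a linear-regression model_fn (assumes the TensorFlow 1.x API).
model_fn <- function(features, labels, mode, params, config) {
  # Assumed input: `features` is a named list with one numeric tensor, `x`.
  x <- tf$reshape(tf$cast(features$x, tf$float32), list(-1L, 1L))
  # A single dense unit gives a linear model; flatten to one prediction per row.
  y_hat <- tf$reshape(tf$layers$dense(x, units = 1L), list(-1L))

  # Prediction mode: only `predictions` is needed.
  if (mode == "infer") {
    return(estimator_spec(mode = mode, predictions = list(y = y_hat)))
  }

  # Flatten the labels so their shape matches `y_hat`.
  labels <- tf$reshape(tf$cast(labels, tf$float32), list(-1L))
  loss <- tf$losses$mean_squared_error(labels = labels, predictions = y_hat)

  # Training mode: `loss` and `train_op` are needed.
  if (mode == "train") {
    optimizer <- tf$train$GradientDescentOptimizer(learning_rate = 0.01)
    train_op <- optimizer$minimize(loss, global_step = tf$train$get_global_step())
    return(estimator_spec(mode = mode, loss = loss, train_op = train_op))
  }

  # Evaluation mode: `loss`, plus any optional metrics.
  estimator_spec(
    mode = mode,
    loss = loss,
    eval_metric_ops = list(
      rmse = tf$metrics$root_mean_squared_error(labels = labels, predictions = y_hat)
    )
  )
}

# Wrap the model_fn in a custom estimator.
model <- estimator(model_fn = model_fn)
```

Each branch returns early, so only the arguments relevant to the current mode are passed to estimator_spec().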

Arguments

mode

A key that specifies whether the model is being trained ("train"), evaluated ("eval"), or used for prediction ("infer"). These values are also available from mode_keys().
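The sketch below compares an incoming mode against the values from mode_keys() instead of hard-coded strings. It assumes that mode_keys() exposes entries named TRAIN, EVAL and PREDICT whose values are "train", "eval" and "infer"; describe_mode() is a hypothetical helper for illustration only.

```r
library(tfestimators)

# Assumption: mode_keys() exposes TRAIN, EVAL and PREDICT entries whose
# values are "train", "eval" and "infer".
keys <- mode_keys()

# Hypothetical helper: map a mode string to a human-readable label.
describe_mode <- function(mode) {
  if (mode == keys$TRAIN) {
    "training"
  } else if (mode == keys$EVAL) {
    "evaluation"
  } else if (mode == keys$PREDICT) {
    "prediction"
  } else {
    stop("unknown mode: ", mode)
  }
}

describe_mode("infer")
```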

predictions

The prediction tensor, or a named list of tensors. Required in prediction mode.
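For a classifier, predictions is often a named list so that several outputs can be returned together. The sketch below assumes the TensorFlow 1.x API and a logits tensor of shape [batch_size, n_classes]; prediction_spec() and the output names class and prob are hypothetical.

```r
library(tensorflow)
library(tfestimators)

# Hypothetical helper for the prediction branch of a classifier's model_fn.
# `logits` is assumed to be a float tensor of shape [batch_size, n_classes].
prediction_spec <- function(mode, logits) {
  predictions <- list(
    class = tf$argmax(logits, axis = 1L),  # index of the highest-scoring class
    prob  = tf$nn$softmax(logits)          # per-class probabilities
  )
  estimator_spec(mode = mode, predictions = predictions)
}
```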

loss

The training loss tensor. Must be a scalar or have shape c(1). Required in training and evaluation modes.

train_op

The training operation, typically the result of calling optimizer$minimize(...) on the loss; the exact call depends on the optimizer used during training. Required in training mode.

eval_metric_ops

A named list of metrics to compute during evaluation, mapping metric names (e.g. "rmse") to the result of the operation that computes each metric (e.g. tf$metrics$root_mean_squared_error(...)). Functions in tf$metrics return a pair: a value tensor and an update operation. Evaluating the value tensor must not change any state (it is typically a pure computation based on variables); for example, it should not trigger the update operation or require fetching any input.
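A sketch of an evaluation branch that reports two classification metrics is shown below. It assumes the TensorFlow 1.x tf$metrics and tf$losses functions, integer class ids in labels, and logits of shape [batch_size, n_classes]; evaluation_spec() and its n_classes argument are hypothetical.

```r
library(tensorflow)
library(tfestimators)

# Hypothetical helper for the evaluation branch of a classifier's model_fn
# (TensorFlow 1.x API). `labels` holds integer class ids; `n_classes` should
# be an R integer such as 3L.
evaluation_spec <- function(mode, labels, logits, n_classes) {
  predicted <- tf$argmax(logits, axis = 1L)
  loss <- tf$losses$sparse_softmax_cross_entropy(labels = labels, logits = logits)
  # Each tf$metrics function returns a (value, update_op) pair.
  eval_metric_ops <- list(
    accuracy = tf$metrics$accuracy(labels = labels, predictions = predicted),
    mean_per_class_accuracy = tf$metrics$mean_per_class_accuracy(
      labels = labels, predictions = predicted, num_classes = n_classes
    )
  )
  estimator_spec(mode = mode, loss = loss, eval_metric_ops = eval_metric_ops)
}
```

The names of the list (accuracy, mean_per_class_accuracy) become the metric names reported by evaluation.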

training_hooks

(Available since TensorFlow v1.4) A list of session run hooks to run on all workers during training.

evaluation_hooks

(Available since TensorFlow v1.4) A list of session run hooks to run during evaluation.

prediction_hooks

(Available since TensorFlow v1.7) A list of session run hooks to run during prediction.

training_chief_hooks

(Available since TensorFlow v1.4) A list of session run hooks to run on the chief worker during training.
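As a sketch of passing hooks, the helper below attaches a logging hook to the training specification. It assumes the TensorFlow 1.x class tf$train$LoggingTensorHook; training_spec_with_logging() and the logging interval of 100 steps are hypothetical choices.

```r
library(tensorflow)
library(tfestimators)

# Hypothetical helper for the training branch: attach a TensorFlow 1.x
# LoggingTensorHook that logs the loss every 100 training steps.
training_spec_with_logging <- function(mode, loss, train_op) {
  logging_hook <- tf$train$LoggingTensorHook(
    tensors = list(loss = loss),  # tag -> tensor to log
    every_n_iter = 100L
  )
  estimator_spec(
    mode = mode,
    loss = loss,
    train_op = train_op,
    training_hooks = list(logging_hook)
  )
}
```

Hooks are always supplied as a list, even when there is only one.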

...

Other optional (named) arguments, to be passed to the EstimatorSpec constructor.
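For example, EstimatorSpec also accepts export_outputs, which estimator_spec() does not name explicitly, so it would be forwarded through the dots. The sketch assumes the TensorFlow 1.x class tf$estimator$export$PredictOutput and that predictions is a named list of tensors; prediction_spec_with_export() is hypothetical.

```r
library(tensorflow)
library(tfestimators)

# Hypothetical helper: `export_outputs` is an EstimatorSpec argument that
# estimator_spec() does not name, so it is forwarded through `...`.
prediction_spec_with_export <- function(mode, predictions) {
  estimator_spec(
    mode = mode,
    predictions = predictions,
    export_outputs = list(
      # Assumes the TensorFlow 1.x tf$estimator$export$PredictOutput class.
      serving_default = tf$estimator$export$PredictOutput(predictions)
    )
  )
}
```

A serving signature like this is typically relevant when the model is later exported with export_savedmodel().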

See Also

Other custom estimator methods: estimator(), evaluate.tf_estimator(), export_savedmodel.tf_estimator(), predict.tf_estimator(), train.tf_estimator()