tfestimators (version 1.9.1)

column_embedding: Construct a Dense Column

Description

Use this when your inputs are sparse, but you want to convert them to a dense representation (e.g., to feed to a DNN). Inputs must be a categorical column created by any of the column_categorical_*() functions.

Usage

column_embedding(categorical_column, dimension, combiner = "mean",
  initializer = NULL, ckpt_to_load_from = NULL,
  tensor_name_in_ckpt = NULL, max_norm = NULL, trainable = TRUE)

Arguments

categorical_column

A categorical column created by a column_categorical_*() function. This column produces the sparse IDs that are inputs to the embedding lookup.

dimension

A positive integer, specifying the dimension of the embedding.

combiner

A string specifying how to reduce when there are multiple entries in a single row. Currently "mean", "sqrtn", and "sum" are supported, with "mean" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of these can be thought of as an example-level normalization of the column. See the sketch following the argument descriptions.

initializer

A variable initializer function used to initialize the embedding variables. If not specified, defaults to tf$truncated_normal_initializer with mean 0.0 and standard deviation 1 / sqrt(dimension).

ckpt_to_load_from

String representing checkpoint name/pattern from which to restore column weights. Required if tensor_name_in_ckpt is not NULL.

tensor_name_in_ckpt

Name of the Tensor in ckpt_to_load_from from which to restore the column weights. Required if ckpt_to_load_from is not NULL.

max_norm

If not NULL, embedding values are l2-normalized to this value.

trainable

Whether or not the embedding is trainable. Default is TRUE.
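
As a minimal sketch of the combiner and initializer arguments, the snippet below builds an embedding column with "sqrtn" reduction and an explicit truncated-normal initializer via the tensorflow package. The column name "terms", its vocabulary, and the dimension are hypothetical choices for illustration only.

library(tfestimators)
library(tensorflow)

# Hypothetical bag-of-words style column with a small vocabulary
terms <- column_categorical_with_vocabulary_list(
  "terms", vocabulary_list = c("good", "bad", "neutral")
)

# Reduce multiple terms per row with "sqrtn" and initialize the embedding
# weights explicitly (here mirroring the documented default scale)
terms_embedded <- column_embedding(
  terms,
  dimension = 4,
  combiner = "sqrtn",
  initializer = tf$truncated_normal_initializer(mean = 0, stddev = 1 / sqrt(4))
)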

Value

A dense column that converts from sparse input.

Raises

  • ValueError: if dimension is not greater than 0.

  • ValueError: if only one of ckpt_to_load_from and tensor_name_in_ckpt is specified.

  • ValueError: if initializer is specified and is not callable.

See Also

Other feature column constructors: column_bucketized, column_categorical_weighted, column_categorical_with_hash_bucket, column_categorical_with_identity, column_categorical_with_vocabulary_file, column_categorical_with_vocabulary_list, column_crossed, column_numeric, input_layer
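
Examples

A minimal sketch of feeding an embedding column to a DNN estimator. The data frame, column names, vocabulary, and hidden layer sizes below are hypothetical, not part of the package documentation.

library(tfestimators)

# Hypothetical toy data: a sparse categorical feature and a binary response
df <- data.frame(
  occupation = c("engineer", "teacher", "artist", "teacher"),
  label = c(1, 0, 1, 0),
  stringsAsFactors = FALSE
)

# Sparse categorical column converted to a dense 2-dimensional embedding
occupation <- column_categorical_with_vocabulary_list(
  "occupation", vocabulary_list = c("engineer", "teacher", "artist")
)
occupation_embedded <- column_embedding(occupation, dimension = 2)

# Feed the dense column to a DNN estimator
model <- dnn_classifier(
  feature_columns = feature_columns(occupation_embedded),
  hidden_units = c(8, 4)
)

model %>% train(input_fn(df, features = "occupation", response = "label"))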