This R6 class stores a text embedding model that can be used to tokenize, encode, decode, and embed raw texts. The object provides a unified interface for different text processing methods.
Objects of class TextEmbeddingModel transform raw texts into numerical representations that can be used for downstream tasks. To this end, objects of this class can tokenize raw texts, encode tokens as sequences of integers, and decode sequences of integers back into tokens.
last_training
('list()')
List for storing the history and the results of the last training. This
information is overwritten when a new training is started.
new()
Method for creating a new text embedding model.
TextEmbeddingModel$new(
model_name = NULL,
model_label = NULL,
model_version = NULL,
model_language = NULL,
method = NULL,
ml_framework = aifeducation_config$get_framework()$TextEmbeddingFramework,
max_length = 0,
chunks = 1,
overlap = 0,
emb_layer_min = "middle",
emb_layer_max = "2_3_layer",
emb_pool_type = "average",
model_dir,
bow_basic_text_rep,
bow_n_dim = 10,
bow_n_cluster = 100,
bow_max_iter = 500,
bow_max_iter_cluster = 500,
bow_cr_criterion = 1e-08,
bow_learning_rate = 1e-08,
trace = FALSE
)
model_name
string containing the name of the new model.
model_label
string containing the label/title of the new model.
model_version
string containing the version of the model.
model_language
string containing the language that the model represents (e.g., English).
method
string determining the kind of embedding model. Currently, the following models are supported:
method="bert" for Bidirectional Encoder Representations from Transformers (BERT),
method="roberta" for A Robustly Optimized BERT Pretraining Approach (RoBERTa),
method="longformer" for the Long-Document Transformer,
method="funnel" for the Funnel-Transformer,
method="deberta_v2" for Decoding-enhanced BERT with Disentangled Attention (DeBERTa V2),
method="glove" for GlobalVector Clusters, and
method="lda" for topic modeling.
See details for more information.
ml_framework
string determining the framework to use for the model: ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'. Only relevant for transformer models.
max_length
int determining the maximum length of the token sequences used in transformer models. Not relevant for the other methods.
chunks
int determining the maximum number of chunks. Only relevant for transformer models.
overlap
int determining the number of tokens from the end of one chunk that are added at the beginning of the next chunk. Only relevant for BERT models.
emb_layer_min
int or string determining the first layer to be included in the creation of embeddings. An integer corresponds to the layer number; the first layer has the number 1. Instead of an integer, the following strings are possible: "start" for the first layer, "middle" for the middle layer, "2_3_layer" for the layer at two thirds of the model's depth, and "last" for the last layer.
emb_layer_max
int or string determining the last layer to be included in the creation of embeddings. An integer corresponds to the layer number; the first layer has the number 1. Instead of an integer, the following strings are possible: "start" for the first layer, "middle" for the middle layer, "2_3_layer" for the layer at two thirds of the model's depth, and "last" for the last layer.
emb_pool_type
string determining the method for pooling the token embeddings within each layer. If "cls", only the embedding of the CLS token is used. If "average", the embeddings of all tokens are averaged (excluding padding tokens).
model_dir
string path to the directory where the pretrained transformer model (e.g., a BERT model) is stored.
bow_basic_text_rep
object of class basic_text_rep created via the function bow_pp_create_basic_text_rep. Only relevant for method="glove_cluster" and method="lda".
bow_n_dim
int determining the number of dimensions of the GlobalVectors or the number of topics for LDA.
bow_n_cluster
int determining the number of clusters created on the basis of the GlobalVectors. Not relevant for method="lda" and method="bert".
bow_max_iter
int determining the maximum number of iterations for fitting the GlobalVectors and topic models.
bow_max_iter_cluster
int determining the maximum number of iterations for fitting the clusters if method="glove".
bow_cr_criterion
double convergence criterion for GlobalVectors.
bow_learning_rate
double initial learning rate for GlobalVectors.
trace
bool If TRUE, information about the progress is printed. If FALSE, it is not.
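The layer and pooling arguments (emb_layer_min, emb_layer_max, emb_pool_type) can be illustrated with a base-R sketch. It assumes that the hidden states of a transformer are available as a list of token-by-feature matrices (one per layer) together with an attention mask, that "middle" and "2_3_layer" are resolved with floor(), and that the pooled vectors of the selected layers are averaged. The helpers resolve_layer(), pool_layer(), and embed_from_layers() are hypothetical and not part of the package; the package's own rounding and layer aggregation may differ.

```r
# Hypothetical helper (not part of the package): map an emb_layer_min /
# emb_layer_max specification to a 1-based layer index.
# ASSUMPTION: "middle" and "2_3_layer" are rounded down with floor().
resolve_layer <- function(spec, n_layers) {
  if (is.numeric(spec)) return(as.integer(spec))
  switch(spec,
    "start"     = 1L,
    "middle"    = as.integer(floor(n_layers / 2)),
    "2_3_layer" = as.integer(floor(2 * n_layers / 3)),
    "last"      = as.integer(n_layers),
    stop("unknown layer specification: ", spec)
  )
}

# Hypothetical helper: pool the token embeddings of one layer.
# hidden: tokens x features matrix; mask: 1 for real tokens, 0 for padding.
# ASSUMPTION: the CLS token is stored in the first row.
pool_layer <- function(hidden, mask, pool_type = c("average", "cls")) {
  pool_type <- match.arg(pool_type)
  if (pool_type == "cls") return(hidden[1, ])
  # mask is recycled down each column, so padding rows contribute zero
  colSums(hidden * mask) / sum(mask)
}

# Hypothetical helper: pool every selected layer, then average the layers.
# ASSUMPTION: all layers from layer_min to layer_max are averaged.
embed_from_layers <- function(hidden_states, mask, layer_min = "middle",
                              layer_max = "2_3_layer", pool_type = "average") {
  n_layers <- length(hidden_states)
  from <- resolve_layer(layer_min, n_layers)
  to <- resolve_layer(layer_max, n_layers)
  pooled <- do.call(cbind, lapply(hidden_states[from:to], pool_layer,
                                  mask = mask, pool_type = pool_type))
  rowMeans(pooled)  # one value per feature
}

# Toy data: 12 layers, 4 tokens (the last one is padding), 2 features.
h <- matrix(c(1, 2, 3, 100, 10, 20, 30, 100), nrow = 4)
mask <- c(1, 1, 1, 0)
hidden_states <- lapply(1:12, function(k) h * k)
embed_from_layers(hidden_states, mask)  # layers 6 to 8 -> 14 140
```

With the defaults of new() ("middle" to "2_3_layer"), a 12-layer model would use layers 6 to 8 under this rounding assumption.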
method: In the case of method="bert", method="roberta", and method="longformer", a pretrained transformer model must be supplied via model_dir. For method="glove" and method="lda", a new model is created based on the data provided via bow_basic_text_rep. The original algorithm for GlobalVectors provides only word embeddings, not text embeddings. To obtain text embeddings, the words are clustered with k-means based on their word embeddings.
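The clustering step described above can be sketched in base R with stats::kmeans(). The sketch assumes that a word-embedding matrix (words in rows) has already been fitted and represents each text by the relative frequency of its tokens in each cluster. The helper glove_cluster_embed() and this normalization are assumptions for illustration, and the toy vectors below are made up rather than trained GlobalVectors.

```r
# Toy word embeddings (words x dimensions). In the package these would be
# GlobalVectors fitted on bow_basic_text_rep; here they are made up.
word_emb <- rbind(
  cat = c(1.0, 0.1), dog = c(0.9, 0.2), mouse = c(1.1, 0.0),
  car = c(0.0, 1.0), bus = c(0.1, 0.9), train = c(0.2, 1.1)
)

# Hypothetical helper: cluster the words with k-means, then describe every
# text by the share of its known tokens that falls into each cluster.
glove_cluster_embed <- function(texts, word_emb, n_cluster, max_iter = 500) {
  km <- stats::kmeans(word_emb, centers = n_cluster, iter.max = max_iter,
                      nstart = 5)
  cluster_of <- setNames(km$cluster, rownames(word_emb))  # word -> cluster
  rows <- lapply(texts, function(tokens) {
    known <- tokens[tokens %in% names(cluster_of)]  # drop unknown words
    counts <- tabulate(cluster_of[known], nbins = n_cluster)
    if (sum(counts) > 0) counts / sum(counts) else as.numeric(counts)
  })
  do.call(rbind, rows)  # texts x clusters
}

set.seed(1)
texts <- list(a = c("cat", "dog"), b = c("car", "bus", "train"),
              c = c("cat", "bus", "unknownword"))
emb <- glove_cluster_embed(texts, word_emb, n_cluster = 2)
emb  # one row per text, one column per cluster
```

The cluster numbering produced by k-means is arbitrary, so only the grouping (animals versus vehicles in this toy example) is meaningful.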
Returns an object of class TextEmbeddingModel.
load_model()
Method for loading a transformer model into R.
TextEmbeddingModel$load_model(model_dir, ml_framework = "auto")
model_dir
string containing the path to the relevant model directory.
ml_framework
string determining the machine learning framework for using the model. Possible values are ml_framework="pytorch" for 'pytorch', ml_framework="tensorflow" for 'tensorflow', and ml_framework="auto".
Function does not return a value. It is used for loading a saved transformer model into the R interface.
save_model()
Method for saving a transformer model to disk. Relevant only for transformer models.
TextEmbeddingModel$save_model(model_dir, save_format = "default")
model_dir
string containing the path to the relevant model directory.
save_format
Format for saving the model. For 'tensorflow'/'keras' models, use "h5" for HDF5. For 'pytorch' models, use "safetensors" for 'safetensors' or "pt" for 'pytorch' via pickle. Use "default" for the standard format, which is h5 for 'tensorflow'/'keras' models and safetensors for 'pytorch' models.
Function does not return a value. It is used for saving a transformer model to disk.
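The rule behind save_format="default" can be written as a small lookup. resolve_save_format() is a hypothetical helper that only mirrors the formats listed above; it is not a function of the package.

```r
# Hypothetical helper (not part of the package): resolve save_format for a
# given framework, following the formats listed in the documentation above.
resolve_save_format <- function(ml_framework = c("tensorflow", "pytorch"),
                                save_format = "default") {
  ml_framework <- match.arg(ml_framework)
  defaults <- c(tensorflow = "h5", pytorch = "safetensors")
  allowed <- list(tensorflow = "h5", pytorch = c("safetensors", "pt"))
  if (identical(save_format, "default")) {
    return(unname(defaults[ml_framework]))
  }
  if (!save_format %in% allowed[[ml_framework]]) {
    stop("save_format '", save_format, "' is not supported for ", ml_framework)
  }
  save_format
}

resolve_save_format("pytorch")        # "safetensors"
resolve_save_format("tensorflow")     # "h5"
resolve_save_format("pytorch", "pt")  # "pt"
```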
encode()
Method for encoding words of raw texts into integers.
TextEmbeddingModel$encode(
raw_text,
token_encodings_only = FALSE,
to_int = TRUE,
trace = FALSE
)
raw_text
vector containing the raw texts.
token_encodings_only
bool If TRUE, only the token encodings are returned. If FALSE, the complete encoding is returned, which is important for BERT models.
to_int
bool If TRUE, the integer ids of the tokens are returned. If FALSE, the tokens are returned. Argument only applies to transformer models and only if token_encodings_only==TRUE.
trace
bool If TRUE, information about the progress is printed. FALSE if not requested.
list containing the integer sequences of the raw texts with special tokens.
decode()
Method for decoding a sequence of integers into tokens.
TextEmbeddingModel$decode(int_seqence, to_token = FALSE)
int_seqence
list containing the integer sequences that should be transformed to tokens or plain text.
to_token
bool If FALSE, plain text is returned. If TRUE, a sequence of tokens is returned. Argument only relevant if the model is based on a transformer.
list of token sequences
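The relationship between encode() and decode() can be illustrated with a toy whitespace tokenizer in base R. Real transformer models use subword tokenizers (such as WordPiece) provided by the Python backend. The vocabulary, the BERT-style special tokens [CLS], [SEP], [PAD], and [UNK], and the helpers toy_encode() and toy_decode() are hypothetical and only show the round trip from raw text to integer sequences with special tokens and back.

```r
# Hypothetical toy vocabulary: token -> integer id.
vocab <- c("[PAD]" = 0L, "[UNK]" = 1L, "[CLS]" = 2L, "[SEP]" = 3L,
           "the" = 4L, "cat" = 5L, "sat" = 6L)

# Hypothetical encoder: whitespace split, wrap in [CLS] ... [SEP],
# map unknown words to [UNK].
toy_encode <- function(raw_text, vocab) {
  lapply(raw_text, function(txt) {
    tokens <- c("[CLS]", strsplit(tolower(txt), "\\s+")[[1]], "[SEP]")
    ids <- vocab[tokens]
    ids[is.na(ids)] <- vocab[["[UNK]"]]
    unname(ids)
  })
}

# Hypothetical decoder: ids -> tokens, or plain text without [CLS]/[SEP]/[PAD].
toy_decode <- function(int_seqence, vocab, to_token = FALSE) {
  id_to_token <- setNames(names(vocab), vocab)
  lapply(int_seqence, function(ids) {
    tokens <- unname(id_to_token[as.character(ids)])
    if (to_token) return(tokens)
    paste(tokens[!tokens %in% c("[CLS]", "[SEP]", "[PAD]")], collapse = " ")
  })
}

enc <- toy_encode(c("The cat sat", "the dog"), vocab)
enc[[1]]                                       # 2 4 5 6 3
toy_decode(enc, vocab)                         # "the cat sat", "the [UNK]"
toy_decode(enc, vocab, to_token = TRUE)[[1]]   # "[CLS]" "the" "cat" "sat" "[SEP]"
```

As the second text shows, decoding is not always an exact inverse: words outside the vocabulary come back as [UNK].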
get_special_tokens()
Method for requesting the special tokens of the model.
TextEmbeddingModel$get_special_tokens()
Returns a matrix containing the special tokens in the rows and their type, token, and id in the columns.
embed()
Method for creating text embeddings from raw texts.
If you use a GPU and run out of memory, reduce the batch size, or restart R and switch to CPU-only mode via set_config_cpu_only.
TextEmbeddingModel$embed(
raw_text = NULL,
doc_id = NULL,
batch_size = 8,
trace = FALSE
)
raw_text
vector containing the raw texts.
doc_id
vector containing the corresponding IDs for every text.
batch_size
int determining the maximal size of every batch.
trace
bool TRUE if information about the progress should be printed on the console.
Method returns an R6 object of class EmbeddedText. This object contains the embeddings as a data.frame and information about the model creating the embeddings.
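For transformer models, the max_length, chunks, and overlap arguments of new() determine how a long text is split before it is embedded. A base-R sketch of that splitting on an already tokenized text follows. The helper chunk_tokens() is hypothetical; it ignores the special tokens that a real tokenizer adds to every chunk and assumes that tokens beyond the last allowed chunk are discarded.

```r
# Hypothetical helper: split a token vector into at most `chunks` windows of
# at most `max_length` tokens; every window after the first starts with the
# last `overlap` tokens of the previous window.
chunk_tokens <- function(tokens, max_length, chunks = 1, overlap = 0) {
  stopifnot(overlap >= 0, max_length > overlap, chunks >= 1)
  step <- max_length - overlap
  out <- list()
  start <- 1
  while (length(out) < chunks && start <= length(tokens)) {
    end <- min(start + max_length - 1, length(tokens))
    out[[length(out) + 1]] <- tokens[start:end]
    if (end == length(tokens)) break  # the whole text is covered
    start <- start + step
  }
  out  # tokens after the last allowed chunk are dropped (truncation)
}

chunk_tokens(1:10, max_length = 4, chunks = 5, overlap = 1)
# list(1:4, 4:7, 7:10)
chunk_tokens(1:10, max_length = 4, chunks = 2, overlap = 1)
# list(1:4, 4:7); tokens 8 to 10 are dropped
```

A larger overlap gives each chunk more context from its predecessor, at the cost of covering fewer new tokens per chunk.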
fill_mask()
Method for predicting the tokens hidden behind mask tokens.
TextEmbeddingModel$fill_mask(text, n_solutions = 5)
text
string Text containing mask tokens.
n_solutions
int Number of estimated tokens for every mask.
Returns a list containing a data.frame for every mask. The data.frame contains the solutions in the rows and reports the score, token id, and token string in the columns.
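How n_solutions limits the result can be sketched for a single mask: the model's scores (logits) over the vocabulary are turned into probabilities with a softmax, and the n_solutions most probable tokens are kept. The helper top_n_solutions(), the toy logits and vocabulary, the 0-based ids, and the column names (score, token, token_str, modeled on the output of the Hugging Face fill-mask pipeline) are assumptions for illustration.

```r
# Hypothetical helper: turn the logits of one mask position into a
# data.frame with the n_solutions most probable tokens.
top_n_solutions <- function(logits, vocab_tokens, n_solutions = 5) {
  probs <- exp(logits - max(logits))  # numerically stable softmax
  probs <- probs / sum(probs)
  n <- min(n_solutions, length(probs))
  ord <- order(probs, decreasing = TRUE)[seq_len(n)]
  data.frame(
    score = probs[ord],
    token = ord - 1L,  # ASSUMPTION: 0-based token ids
    token_str = vocab_tokens[ord],
    stringsAsFactors = FALSE
  )
}

res <- top_n_solutions(c(0, 2, 1, -1), c("a", "b", "c", "d"), n_solutions = 2)
res$token_str  # "b" "c"
```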
set_publication_info()
Method for setting the bibliographic information of the model.
TextEmbeddingModel$set_publication_info(type, authors, citation, url = NULL)
type
string Type of information that should be changed/added. type="developer" and type="modifier" are possible.
authors
List of people.
citation
string Citation in free text.
url
string Corresponding URL, if applicable.
Function does not return a value. It is used to set the private members for publication information of the model.
get_publication_info()
Method for getting the bibliographic information of the model.
TextEmbeddingModel$get_publication_info()
list of bibliographic information.
set_software_license()
Method for setting the license of the model.
TextEmbeddingModel$set_software_license(license = "GPL-3")
license
string containing the abbreviation of the license or the license text.
Function does not return a value. It is used for setting the private member for the software license of the model.
get_software_license()
Method for requesting the license of the model.
TextEmbeddingModel$get_software_license()
string License of the model.
set_documentation_license()
Method for setting the license of the model's documentation.
TextEmbeddingModel$set_documentation_license(license = "CC BY-SA")
license
string containing the abbreviation of the license or the license text.
Function does not return a value. It is used to set the private member for the documentation license of the model.
get_documentation_license()
Method for getting the license of the model's documentation.
TextEmbeddingModel$get_documentation_license()
license
string containing the abbreviation of the license or the license text.
set_model_description()
Method for setting a description of the model.
TextEmbeddingModel$set_model_description(
eng = NULL,
native = NULL,
abstract_eng = NULL,
abstract_native = NULL,
keywords_eng = NULL,
keywords_native = NULL
)
eng
string A text describing the training of the model, its theoretical and empirical background, and the different output labels in English.
native
string A text describing the training of the model, its theoretical and empirical background, and the different output labels in the native language of the model.
abstract_eng
string A text providing a summary of the description in English.
abstract_native
string A text providing a summary of the description in the native language of the model.
keywords_eng
vector of keywords in English.
keywords_native
vector of keywords in the native language of the model.
Function does not return a value. It is used to set the private members for the description of the model.
get_model_description()
Method for requesting the model description.
TextEmbeddingModel$get_model_description()
list with the description of the model in English and the native language.
get_model_info()
Method for requesting the model information.
TextEmbeddingModel$get_model_info()
list of all relevant model information.
get_package_versions()
Method for requesting a summary of the versions of the R and Python packages used for creating the model.
TextEmbeddingModel$get_package_versions()
Returns a list containing the versions of the relevant R and Python packages.
get_basic_components()
Method for requesting the part of the interface's configuration that is necessary for all models.
TextEmbeddingModel$get_basic_components()
Returns a list.
get_bow_components()
Method for requesting the part of the interface's configuration that is necessary for bag-of-words models.
TextEmbeddingModel$get_bow_components()
Returns a list.
get_transformer_components()
Method for requesting the part of the interface's configuration that is necessary for transformer models.
TextEmbeddingModel$get_transformer_components()
Returns a list.
get_sustainability_data()
Method for requesting a log of tracked energy consumption during training and an estimate of the resulting CO2 equivalents in kg.
TextEmbeddingModel$get_sustainability_data()
Returns a matrix containing the tracked energy consumption, CO2 equivalents in kg, information on the tracker used, and technical information on the training infrastructure for every training run.
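Energy trackers such as the Python library codecarbon estimate CO2 equivalents by multiplying the measured energy consumption (kWh) by the carbon intensity of the electricity used (kg CO2eq per kWh). The following sketch applies that arithmetic to made-up values for three training runs; the numbers and column names are hypothetical and do not reproduce the matrix returned by get_sustainability_data().

```r
# Hypothetical energy log for three training runs (made-up values).
runs <- data.frame(run = 1:3, energy_kwh = c(1.2, 0.8, 2.0))

# Hypothetical carbon intensity of the electricity grid in kg CO2eq per kWh.
intensity_kg_per_kwh <- 0.4

# CO2 equivalents per run and in total.
runs$co2eq_kg <- runs$energy_kwh * intensity_kg_per_kwh
total_co2eq_kg <- sum(runs$co2eq_kg)
runs
total_co2eq_kg  # 1.6
```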
get_ml_framework()
Method for requesting the machine learning framework used for the model.
TextEmbeddingModel$get_ml_framework()
Returns a string describing the machine learning framework used for the model.
clone()
The objects of this class are cloneable with this method.
TextEmbeddingModel$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other Text Embedding: EmbeddedText, combine_embeddings()