Base class for classifiers relying on EmbeddedText or LargeDataSetForTextEmbeddings generated with a TextEmbeddingModel.
Objects of this class contain fields and methods used in several other classes in 'AI for Education'.
This class is not designed for a direct application and should only be used by developers.
A new object of this class.
aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> ClassifiersBasedOnTextEmbeddings
feature_extractor('list()')
List for storing information and objects about the feature_extractor.
reliability('list()')
List for storing central reliability measures of the last training.
reliability$test_metric: Array containing the reliability measures for the test data for
every fold and step (in case of pseudo-labeling).
reliability$test_metric_mean: Array containing the reliability measures for the test data.
The values represent the mean values for every fold.
reliability$raw_iota_objects: List containing all iota_objects generated with the package iotarelr
for every fold at the end of the last training for the test data.
reliability$raw_iota_objects$iota_objects_end: List of objects with class iotarelr_iota2 containing the
estimated iota reliability of the second generation for the final model for every fold for the test data.
reliability$raw_iota_objects$iota_objects_end_free: List of objects with class iotarelr_iota2 containing
the estimated iota reliability of the second generation for the final model for every fold for the test data.
Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the
assumption of weak superiority.
reliability$iota_object_end: Object of class iotarelr_iota2 as a mean of the individual objects
for every fold for the test data.
reliability$iota_object_end_free: Object of class iotarelr_iota2 as a mean of the individual objects
for every fold. Please note that the model is estimated without forcing the Assignment Error Matrix to be in
line with the assumption of weak superiority.
reliability$standard_measures_end: Object of class list containing the final measures for precision,
recall, and f1 for every fold.
reliability$standard_measures_mean: Matrix containing the mean measures for precision, recall, and f1.
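The relationship between the per-fold measures (standard_measures_end) and their mean (standard_measures_mean) can be illustrated with a minimal base-R sketch. It assumes that each fold yields a factor of true labels and a factor of predicted labels sharing the same levels. The helper `fold_measures()` and the toy data are hypothetical, and the package's own computation may differ, for example in how it handles classes that are never predicted.

```r
# Hypothetical helper (not part of aifeducation): per-class precision,
# recall, and F1 for one fold. Assumes `truth` and `pred` share levels.
fold_measures <- function(truth, pred) {
  cm <- table(truth, pred)          # rows = true class, columns = predicted class
  tp <- diag(cm)                    # correctly assigned cases per class
  precision <- tp / colSums(cm)     # NaN if a class is never predicted
  recall    <- tp / rowSums(cm)     # NaN if a class never occurs
  f1 <- 2 * precision * recall / (precision + recall)
  cbind(precision = precision, recall = recall, f1 = f1)
}

# Hypothetical toy data for two folds.
lv <- c("a", "b")
folds <- list(
  list(truth = factor(c("a", "a", "b", "b"), levels = lv),
       pred  = factor(c("a", "b", "b", "b"), levels = lv)),
  list(truth = factor(c("a", "b", "b", "a"), levels = lv),
       pred  = factor(c("a", "b", "a", "a"), levels = lv))
)

# Analogue of standard_measures_end: one matrix per fold.
per_fold <- lapply(folds, function(f) fold_measures(f$truth, f$pred))

# Analogue of standard_measures_mean: element-wise mean over all folds.
mean_measures <- Reduce(`+`, per_fold) / length(per_fold)
print(mean_measures)
```

Averaging element-wise over folds keeps one row per class, so the mean matrix has the same shape as each per-fold matrix.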
Inherited methods:
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$is_trained()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
predict(): Method for predicting new data with a trained neural net.
ClassifiersBasedOnTextEmbeddings$predict(
newdata,
batch_size = 32,
ml_trace = 1
)
newdata: Object of class EmbeddedText or LargeDataSetForTextEmbeddings for which predictions
should be made. In addition, this method accepts objects of class array and
datasets.arrow_dataset.Dataset. However, these should only be used by developers.
batch_size: int Size of batches.
ml_trace: int ml_trace = 0 does not print any information about the process from the machine learning
framework.
Returns a data.frame containing the predictions and the probabilities of the different labels for each
case.
check_embedding_model(): Method for checking whether the provided text embeddings were created with the same TextEmbeddingModel as the classifier.
ClassifiersBasedOnTextEmbeddings$check_embedding_model(
text_embeddings,
require_compressed = FALSE
)
text_embeddings: Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
require_compressed: TRUE if a compressed version of the embeddings is necessary. Compressed embeddings
are created by an object of class TEFeatureExtractor.
Returns TRUE if the underlying TextEmbeddingModel is the same and FALSE if the models differ.
check_feature_extractor_object_type(): Method for checking an object of class TEFeatureExtractor.
ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type(
feature_extractor
)
feature_extractor: Object of class TEFeatureExtractor.
This method does not return anything. It raises an error if
the object is NULL,
the object does not rely on the same machine learning framework as the classifier, or
the object is not trained.
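The three conditions above can be sketched as a sequence of guard clauses that stop at the first violation. This is a hypothetical illustration: the function `check_extractor_sketch()`, the list fields `ml_framework` and `trained`, and the framework string "pytorch" are stand-ins chosen for the example, not the actual internals of TEFeatureExtractor.

```r
# Hypothetical sketch of the checks listed above. The list fields
# `ml_framework` and `trained` are illustrative stand-ins, not the real
# fields of a TEFeatureExtractor object.
check_extractor_sketch <- function(feature_extractor, classifier_framework) {
  # Condition 1: the object must exist.
  if (is.null(feature_extractor)) {
    stop("Feature extractor is NULL.")
  }
  # Condition 2: both objects must use the same machine learning framework.
  if (!identical(feature_extractor$ml_framework, classifier_framework)) {
    stop("Feature extractor uses a different machine learning framework.")
  }
  # Condition 3: the extractor must already be trained.
  if (!isTRUE(feature_extractor$trained)) {
    stop("Feature extractor is not trained.")
  }
  invisible(NULL)  # nothing is returned when all checks pass
}

ok <- list(ml_framework = "pytorch", trained = TRUE)
check_extractor_sketch(ok, "pytorch")  # passes silently

# An untrained extractor raises an error, caught here for illustration.
msg <- tryCatch(
  check_extractor_sketch(list(ml_framework = "pytorch", trained = FALSE), "pytorch"),
  error = function(e) conditionMessage(e)
)
print(msg)
```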
requires_compression(): Method for checking whether the provided text embeddings must be compressed via a TEFeatureExtractor before processing.
ClassifiersBasedOnTextEmbeddings$requires_compression(text_embeddings)
text_embeddings: Object of class EmbeddedText, LargeDataSetForTextEmbeddings, array, or
datasets.arrow_dataset.Dataset.
Returns TRUE if compression is necessary and FALSE otherwise.
save(): Method for saving a model.
ClassifiersBasedOnTextEmbeddings$save(dir_path, folder_name)
dir_path: string Path of the directory where the model should be saved.
folder_name: string Name of the folder that should be created within the directory.
Function does not return a value. It saves the model to disk.
load_from_disk(): Loads an object from disk and updates the object to the current version of the package.
ClassifiersBasedOnTextEmbeddings$load_from_disk(dir_path)
dir_path: Path where the object is stored.
Method does not return anything. It loads an object from disk.
adjust_target_levels(): Method transforms the levels of a factor into numbers corresponding to the model's definition.
ClassifiersBasedOnTextEmbeddings$adjust_target_levels(data_targets)
data_targets: factor containing the labels for cases stored in embeddings. The factor must be
named and has to use the same names as used in the embeddings.
Method returns a factor containing the numerical representation of
categories/classes.
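The idea of mapping named labels to the numeric codes of a model can be sketched in base R. This is a hypothetical illustration only: the helper `adjust_levels_sketch()`, its arguments `model_levels` and `case_ids`, and the use of zero-based codes are assumptions made for the example, not the package's implementation.

```r
# Hypothetical sketch of the idea behind adjust_target_levels(). The helper
# name, its arguments, and zero-based codes are assumptions for illustration.
adjust_levels_sketch <- function(data_targets, model_levels, case_ids) {
  # The factor must be named so that labels can be matched to cases.
  stopifnot(!is.null(names(data_targets)))
  # Align the labels with the order of the cases in the embeddings.
  aligned <- data_targets[case_ids]
  # Translate each label into its position in the model's level order
  # (assumed zero-based here).
  codes <- match(as.character(aligned), model_levels) - 1L
  factor(setNames(codes, case_ids))
}

targets <- factor(c(doc2 = "pos", doc1 = "neg", doc3 = "pos"))
model_levels <- c("neg", "pos")
res <- adjust_levels_sketch(targets, model_levels, case_ids = c("doc1", "doc2", "doc3"))
print(res)
```

Matching by name rather than by position is what makes the naming requirement on data_targets necessary: the labels may arrive in a different order than the cases in the embeddings.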
plot_training_history(): Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
ClassifiersBasedOnTextEmbeddings$plot_training_history(
final_training = FALSE,
pl_step = NULL,
measure = "loss",
y_min = NULL,
y_max = NULL,
add_min_max = TRUE,
text_size = 10
)
final_training: bool If FALSE, the values of the performance estimation are used. If TRUE, only
the epochs of the final training are used.
pl_step: int Number of the step during pseudo-labeling to plot. Only relevant if the model was trained
with active pseudo-labeling.
measure: string Measure to plot. Allowed values:
"avg_iota" = Average Iota
"loss" = Loss
"accuracy" = Accuracy
"balanced_accuracy" = Balanced Accuracy
y_min: Minimal value for the y-axis. Set to NULL for an automatic adjustment.
y_max: Maximal value for the y-axis. Set to NULL for an automatic adjustment.
add_min_max: bool If TRUE, the minimal and maximal values during performance estimation are part of the plot.
If FALSE, only the mean values are shown. Parameter is ignored if final_training=TRUE.
text_size: Size of the text.
Returns a plot of class ggplot visualizing the training process.
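The values controlled by add_min_max can be sketched in base R: for every epoch, the mean over folds forms the main curve, and the minimum and maximum over folds form the envelope. The matrix layout (epochs in rows, folds in columns) and the loss values are hypothetical, and the package stores its training history in its own structure. The plotting step (for example, a line for `mean` and a ribbon between `min` and `max`) is omitted to keep the sketch free of dependencies.

```r
# Hypothetical loss values: rows are epochs, columns are folds.
loss_by_fold <- matrix(
  c(0.90, 0.70, 0.55,   # fold 1, epochs 1 to 3
    0.80, 0.65, 0.60,   # fold 2
    1.00, 0.75, 0.50),  # fold 3
  nrow = 3,
  dimnames = list(paste0("epoch_", 1:3), paste0("fold_", 1:3))
)

# Mean curve plus the min/max envelope across folds for every epoch.
history_summary <- data.frame(
  epoch = seq_len(nrow(loss_by_fold)),
  mean  = rowMeans(loss_by_fold),
  min   = apply(loss_by_fold, 1, min),
  max   = apply(loss_by_fold, 1, max)
)
print(history_summary)
```

With add_min_max = FALSE, only the `mean` column would be relevant.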
plot_coding_stream(): Method for requesting a plot of the coding stream. The plot shows how the cases of different categories/classes are assigned to the available classes/categories. The visualization is helpful for analyzing the consequences of coding errors.
ClassifiersBasedOnTextEmbeddings$plot_coding_stream(
label_categories_size = 3,
key_size = 0.5,
text_size = 10
)
label_categories_size: double determining the size of the label for each true and assigned category within the plot.
key_size: double determining the size of the legend.
text_size: double determining the size of the text within the legend.
Returns a plot of class ggplot visualizing the coding stream.
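A coding stream connects each true category with the categories to which its cases are assigned. The simplest version of the underlying quantities is a cross-tabulation of true and assigned labels, sketched below with base R's `table()`. The toy labels are hypothetical, and the package may base its plot on estimated assignment probabilities (for example, from the iota Assignment Error Matrix) rather than on raw counts.

```r
# Hypothetical true and assigned categories for eight cases.
lv <- c("A", "B", "C")
truth    <- factor(c("A", "A", "A", "B", "B", "C", "C", "C"), levels = lv)
assigned <- factor(c("A", "A", "B", "B", "B", "C", "A", "C"), levels = lv)

# One row per (true, assigned) pair with the number of cases in that stream.
flows <- as.data.frame(table(true = truth, assigned = assigned))

# Keep only the streams that actually carry cases.
flows <- flows[flows$Freq > 0, ]
print(flows)
```

Off-diagonal rows, such as A assigned to B, are the streams that reveal coding errors.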
clone(): The objects of this class are cloneable with this method.
ClassifiersBasedOnTextEmbeddings$clone(deep = FALSE)
deep: Whether to make a deep clone.
Other R6 Classes for Developers:
AIFEBaseModel,
LargeDataSetBase,
ModelsBasedOnTextEmbeddings,
TEClassifiersBasedOnProtoNet,
TEClassifiersBasedOnRegular