Base class for classifiers relying on EmbeddedText or LargeDataSetForTextEmbeddings generated with a TextEmbeddingModel.
Objects of this class contain fields and methods used in several other classes in 'AI for Education'.
This class is not designed for a direct application and should only be used by developers.
A new object of this class.
aifeducation::AIFEBaseModel
-> aifeducation::ModelsBasedOnTextEmbeddings
-> ClassifiersBasedOnTextEmbeddings
feature_extractor
('list()')
List for storing information and objects about the feature_extractor.
reliability
('list()')
List for storing central reliability measures of the last training.
reliability$test_metric
: Array containing the reliability measures for the test data for
every fold and step (in case of pseudo-labeling).
reliability$test_metric_mean
: Array containing the reliability measures for the test data.
The values represent the mean values for every fold.
reliability$raw_iota_objects
: List containing all iota objects generated with the package iotarelr
for every fold at the end of the last training for the test data.
reliability$raw_iota_objects$iota_objects_end
: List of objects with class iotarelr_iota2
containing the
estimated iota reliability of the second generation for the final model for every fold for the test data.
reliability$raw_iota_objects$iota_objects_end_free
: List of objects with class iotarelr_iota2
containing
the estimated iota reliability of the second generation for the final model for every fold for the test data.
Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the
assumption of weak superiority.
reliability$iota_object_end
: Object of class iotarelr_iota2
as a mean of the individual objects
for every fold for the test data.
reliability$iota_object_end_free
: Object of class iotarelr_iota2
as a mean of the individual objects
for every fold. Please note that the model is estimated without forcing the Assignment Error Matrix to be in
line with the assumption of weak superiority.
reliability$standard_measures_end
: Object of class list
containing the final measures for precision,
recall, and f1 for every fold.
reliability$standard_measures_mean
: matrix
containing the mean measures for precision, recall, and f1.
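The relationship between the fold-wise measures and their means can be illustrated with a small base R sketch. The layout (folds in rows, measures in columns) and all values are invented for illustration; the actual dimensions of reliability$test_metric may differ, for example when pseudo-labeling adds a step dimension.

```r
# Hypothetical fold-wise results: 3 folds x 2 measures (values are invented).
test_metric <- matrix(
  c(0.80, 0.70,
    0.90, 0.60,
    0.70, 0.80),
  nrow = 3, byrow = TRUE,
  dimnames = list(
    c("fold_1", "fold_2", "fold_3"),
    c("accuracy", "avg_iota")
  )
)

# Average every measure over the folds, analogous to test_metric_mean.
test_metric_mean <- colMeans(test_metric)
print(test_metric_mean)
```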
Inherited methods
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$is_trained()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
predict()
Method for predicting new data with a trained neural net.
ClassifiersBasedOnTextEmbeddings$predict(
newdata,
batch_size = 32,
ml_trace = 1
)
newdata
Object of class EmbeddedText or LargeDataSetForTextEmbeddings for which predictions
should be made. This method also accepts objects of class array and datasets.arrow_dataset.Dataset.
However, these should be used only by developers.
batch_size
int
Size of batches.
ml_trace
int
ml_trace=0
does not print any information on the process from the machine learning
framework.
Returns a data.frame
containing the predictions and the probabilities of the different labels for each
case.
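A minimal sketch of how predict() might be wrapped, using only the arguments documented above. The helper name predict_quietly and the argument names classifier and embeddings are hypothetical; classifier is assumed to be a trained object of a class inheriting from ClassifiersBasedOnTextEmbeddings.

```r
# Hypothetical helper: run predictions without framework output.
# `classifier`: assumed to inherit from ClassifiersBasedOnTextEmbeddings.
# `embeddings`: assumed to be an EmbeddedText or
#               LargeDataSetForTextEmbeddings object.
predict_quietly <- function(classifier, embeddings, batch_size = 32) {
  # is_trained() is inherited from AIFEBaseModel.
  if (!classifier$is_trained()) {
    stop("The classifier must be trained before predicting.")
  }
  classifier$predict(
    newdata = embeddings,
    batch_size = batch_size,
    ml_trace = 0 # no output from the machine learning framework
  )
}
```

The result is whatever predict() returns, which according to the description above is a data.frame with predictions and label probabilities per case.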
check_embedding_model()
Method for checking if the provided text embeddings are created with the same TextEmbeddingModel as the classifier.
ClassifiersBasedOnTextEmbeddings$check_embedding_model(
text_embeddings,
require_compressed = FALSE
)
text_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
require_compressed
TRUE
if a compressed version of the embeddings is necessary. Compressed embeddings
are created by an object of class TEFeatureExtractor.
TRUE
if the underlying TextEmbeddingModel is the same. FALSE
if the models differ.
check_feature_extractor_object_type()
Method for checking an object of class TEFeatureExtractor.
ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type(
feature_extractor
)
feature_extractor
Object of class TEFeatureExtractor.
This method returns nothing. It raises an error if
the object is NULL,
the object does not rely on the same machine learning framework as the classifier, or
the object is not trained.
requires_compression()
Method for checking if provided text embeddings must be compressed via a TEFeatureExtractor before processing.
ClassifiersBasedOnTextEmbeddings$requires_compression(text_embeddings)
text_embeddings
Object of class EmbeddedText, LargeDataSetForTextEmbeddings, array, or datasets.arrow_dataset.Dataset.
Returns TRUE
if a compression is necessary and FALSE
if not.
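The two checks described above could be combined before passing embeddings to a classifier. The following is only a sketch: the helper name inspect_embeddings and its arguments are hypothetical, and it merely reports whether compression is needed instead of performing it.

```r
# Hypothetical helper combining both checks.
# `classifier`: assumed to inherit from ClassifiersBasedOnTextEmbeddings.
# `embeddings`: assumed to be an EmbeddedText or
#               LargeDataSetForTextEmbeddings object.
inspect_embeddings <- function(classifier, embeddings) {
  same_model <- classifier$check_embedding_model(
    text_embeddings = embeddings,
    require_compressed = FALSE
  )
  if (!isTRUE(same_model)) {
    stop("Embeddings were not created with the classifier's TextEmbeddingModel.")
  }
  needs_compression <- classifier$requires_compression(
    text_embeddings = embeddings
  )
  list(same_model = same_model, needs_compression = needs_compression)
}
```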
save()
Method for saving a model.
ClassifiersBasedOnTextEmbeddings$save(dir_path, folder_name)
dir_path
string
Path of the directory where the model should be saved.
folder_name
string
Name of the folder that should be created within the directory.
Function does not return a value. It saves the model to disk.
load_from_disk()
Loads an object from disk and updates the object to the current version of the package.
ClassifiersBasedOnTextEmbeddings$load_from_disk(dir_path)
dir_path
Path where the object is stored.
Method does not return anything. It loads an object from disk.
adjust_target_levels()
Method for transforming the levels of a factor into numbers corresponding to the model's definition.
ClassifiersBasedOnTextEmbeddings$adjust_target_levels(data_targets)
data_targets
factor
containing the labels for cases stored in embeddings. Factor must be
named and has to use the same names as used in the embeddings.
Method returns a factor
containing the numerical representation of
categories/classes.
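The idea of mapping factor levels to numbers can be illustrated in base R. This is not the package's implementation: the class order in model_levels, the zero-based coding, and the case names are all assumptions made for this example.

```r
# Hypothetical class order as a model might define it.
model_levels <- c("low", "medium", "high")

# Named factor of labels; the names play the role of case IDs.
data_targets <- factor(c(case_a = "high", case_b = "low", case_c = "medium"))

# Map every label to its zero-based position in model_levels.
# Zero-based coding is an assumption made for this illustration.
codes <- match(as.character(data_targets), model_levels) - 1L

# Store the codes as a factor and keep the case names.
adjusted <- factor(codes, levels = seq_along(model_levels) - 1L)
names(adjusted) <- names(data_targets)
print(adjusted)
```

Note that the default alphabetical level order of data_targets ("high", "low", "medium") differs from model_levels, which is why the mapping goes through the labels rather than the factor's internal integer codes.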
plot_training_history()
Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
ClassifiersBasedOnTextEmbeddings$plot_training_history(
final_training = FALSE,
pl_step = NULL,
measure = "loss",
y_min = NULL,
y_max = NULL,
add_min_max = TRUE,
text_size = 10
)
final_training
bool
If FALSE
the values of the performance estimation are used. If TRUE
only
the epochs of the final training are used.
pl_step
int
Number of the step during pseudo labeling to plot. Only relevant if the model was trained
with active pseudo labeling.
measure
string
Measure to plot. Allowed values:
"avg_iota"
= Average Iota
"loss"
= Loss
"accuracy"
= Accuracy
"balanced_accuracy"
= Balanced Accuracy
y_min
Minimal value for the y-axis. Set to NULL
for an automatic adjustment.
y_max
Maximal value for the y-axis. Set to NULL
for an automatic adjustment.
add_min_max
bool
If TRUE
the minimal and maximal values during performance estimation are part of the plot.
If FALSE
only the mean values are shown. Parameter is ignored if final_training=TRUE.
text_size
Size of the text.
Returns a plot of class ggplot
visualizing the training process.
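A possible way to call plot_training_history() while guarding against unsupported measures is sketched below. The helper name plot_history is hypothetical, and classifier is assumed to be a trained object of a class inheriting from ClassifiersBasedOnTextEmbeddings; only the arguments and allowed values documented above are used.

```r
# Hypothetical helper: validate `measure` before requesting the plot.
plot_history <- function(classifier,
                         measure = c("loss", "accuracy",
                                     "balanced_accuracy", "avg_iota"),
                         final_training = FALSE) {
  # match.arg() stops with an error for values outside the allowed set
  # and picks "loss" if no value is supplied.
  measure <- match.arg(measure)
  classifier$plot_training_history(
    final_training = final_training,
    pl_step = NULL,
    measure = measure,
    add_min_max = TRUE,
    text_size = 10
  )
}
```

Leaving y_min and y_max at their NULL defaults lets the method choose the axis limits automatically, as described above.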
plot_coding_stream()
Method for requesting a plot of the coding stream. The plot shows how the cases of different categories/classes are assigned to the available classes/categories. The visualization is helpful for analyzing the consequences of coding errors.
ClassifiersBasedOnTextEmbeddings$plot_coding_stream(
label_categories_size = 3,
key_size = 0.5,
text_size = 10
)
label_categories_size
double
determining the size of the label for each true and assigned category within the plot.
key_size
double
determining the size of the legend.
text_size
double
determining the size of the text within the legend.
Returns a plot of class ggplot
visualizing the coding stream.
clone()
The objects of this class are cloneable with this method.
ClassifiersBasedOnTextEmbeddings$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other R6 Classes for Developers:
AIFEBaseModel, LargeDataSetBase, ModelsBasedOnTextEmbeddings, TEClassifiersBasedOnProtoNet, TEClassifiersBasedOnRegular