Text embedding classifier based on a prototypical network with 'pytorch'.
This object represents an implementation of a prototypical network for few-shot learning as described by Snell, Swersky, and Zemel (2017). The network uses a multi-way contrastive loss as described by Zhang et al. (2019) and learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
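As a rough sketch of the underlying idea (the exact implementation in the package may differ), a case x is assigned to class k via a softmax over its scaled distances to the class prototypes, where each prototype c_k is the mean embedding of the support cases S_k of class k:

p(y = k \mid x) = \frac{\exp\left(-\alpha \, d\left(f(x), c_k\right)\right)}{\sum_{k'} \exp\left(-\alpha \, d\left(f(x), c_{k'}\right)\right)}, \qquad c_k = \frac{1}{|S_k|} \sum_{x_i \in S_k} f(x_i)

Here f is the learned embedding function, d a distance measure, and \alpha the learned metric scale (Oreshkin, Rodriguez & Lacoste, 2018). The multi-way contrastive loss (Zhang et al., 2019) additionally pulls each case towards its own prototype and pushes it at least a fixed margin away from the prototypes of all other classes.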
Objects of this class are used for assigning texts to classes/categories. For the creation and training of a classifier, an object of class EmbeddedText or LargeDataSetForTextEmbeddings and a factor are necessary. The object of class EmbeddedText or LargeDataSetForTextEmbeddings contains the numerical text representations (text embeddings) of the raw texts generated by an object of class TextEmbeddingModel. The factor contains the classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions, an object of class EmbeddedText or LargeDataSetForTextEmbeddings has to be used that was created with the same TextEmbeddingModel as during training.
aifeducation::AIFEBaseModel -> aifeducation::TEClassifierRegular -> TEClassifierProtoNet
Inherited methods

From aifeducation::AIFEBaseModel: count_parameter(), get_all_fields(), get_documentation_license(), get_ml_framework(), get_model_description(), get_model_info(), get_model_license(), get_package_versions(), get_private(), get_publication_info(), get_sustainability_data(), get_text_embedding_model(), get_text_embedding_model_name(), is_configured(), load(), set_documentation_license(), set_model_description(), set_model_license(), set_publication_info()

From aifeducation::TEClassifierRegular: check_embedding_model(), check_feature_extractor_object_type(), load_from_disk(), predict(), requires_compression(), save()
configure(): Creates a new instance of this class.
TEClassifierProtoNet$configure(
ml_framework = "pytorch",
name = NULL,
label = NULL,
text_embeddings = NULL,
feature_extractor = NULL,
target_levels = NULL,
dense_size = 4,
dense_layers = 0,
rec_size = 4,
rec_layers = 2,
rec_type = "gru",
rec_bidirectional = FALSE,
embedding_dim = 2,
self_attention_heads = 0,
intermediate_size = NULL,
attention_type = "fourier",
add_pos_embedding = TRUE,
rec_dropout = 0.1,
repeat_encoder = 1,
dense_dropout = 0.4,
recurrent_dropout = 0.4,
encoder_dropout = 0.1,
optimizer = "adam"
)

ml_framework: string. Currently only pytorch is supported (ml_framework = "pytorch").
name: string. Name of the new classifier. Please refer to common name conventions. Free text can be used with the parameter label.
label: string. Label for the new classifier. Here you can use free text.
text_embeddings: An object of class EmbeddedText or LargeDataSetForTextEmbeddings.
feature_extractor: Object of class TEFeatureExtractor which should be used in order to reduce the number of dimensions of the text embeddings. Set to NULL if no feature extractor should be applied.
target_levels: vector containing the levels (categories or classes) within the target data. Please note that order matters. For ordinal data, please ensure that the levels are sorted correctly, with later levels indicating a higher category/class. For nominal data the order does not matter.
dense_size: int. Number of neurons for each dense layer.
dense_layers: int. Number of dense layers.
rec_size: int. Number of neurons for each recurrent layer.
rec_layers: int. Number of recurrent layers.
rec_type: string. Type of the recurrent layers. rec_type = "gru" for Gated Recurrent Unit and rec_type = "lstm" for Long Short-Term Memory.
rec_bidirectional: bool. If TRUE, a bidirectional version of the recurrent layers is used.
embedding_dim: int. Number of dimensions for the text embedding.
self_attention_heads: int. Number of attention heads for a self-attention layer. Only relevant if attention_type = "multihead".
intermediate_size: int. Size of the projection layer within each transformer encoder.
attention_type: string. Type of attention to use. Possible values are "fourier" and "multihead". Please note that with "fourier" on Linux, the values for a case may differ depending on the order of the input.
add_pos_embedding: bool. TRUE if positional embedding should be used.
rec_dropout: double in the range [0, 1). Dropout between bidirectional recurrent layers.
repeat_encoder: int. Number of times the encoder should be added to the network.
dense_dropout: double in the range [0, 1). Dropout between dense layers.
recurrent_dropout: double in the range [0, 1). Recurrent dropout for each recurrent layer. Only relevant for keras models.
encoder_dropout: double in the range [0, 1). Dropout for the dense projection within the encoder layers.
optimizer: string. "adam" or "rmsprop".
Returns an object of class TEClassifierProtoNet which is ready for training.
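Example (a minimal, untested sketch): creating and configuring a classifier. The usual R6 pattern of calling $new() before $configure() is assumed here; the object embeddings and the target levels are hypothetical placeholders for data created elsewhere with a TextEmbeddingModel.

library(aifeducation)

# 'embeddings' is a hypothetical object of class EmbeddedText or
# LargeDataSetForTextEmbeddings created elsewhere with a TextEmbeddingModel.
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  ml_framework = "pytorch",
  name = "proto_classifier_demo",               # hypothetical name
  label = "Demo prototypical network classifier",
  text_embeddings = embeddings,
  feature_extractor = NULL,                     # apply no feature extractor
  target_levels = c("low", "medium", "high"),   # order matters for ordinal data
  rec_layers = 2,
  rec_type = "gru",
  embedding_dim = 2,                            # 2 dimensions can be plotted directly
  attention_type = "fourier",
  optimizer = "adam"
)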
train(): Method for training the neural net.
Training includes a routine for early stopping. If loss < 0.0001, Accuracy = 1.00, and Average Iota = 1.00, training stops. The history records the values of the last trained epoch for the remaining epochs.
After training, the model with the best values for Average Iota, Accuracy, and Loss on the validation data set is used as the final model.
TEClassifierProtoNet$train(
data_embeddings,
data_targets,
data_folds = 5,
data_val_size = 0.25,
use_sc = TRUE,
sc_method = "dbsmote",
sc_min_k = 1,
sc_max_k = 10,
use_pl = TRUE,
pl_max_steps = 3,
pl_max = 1,
pl_anchor = 1,
pl_min = 0,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 35,
Ns = 5,
Nq = 3,
loss_alpha = 0.5,
loss_margin = 0.5,
sampling_separate = FALSE,
sampling_shuffle = TRUE,
dir_checkpoint,
trace = TRUE,
ml_trace = 1,
log_dir = NULL,
log_write_interval = 10,
n_cores = auto_n_cores()
)

data_embeddings: Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
data_targets: factor containing the labels for the cases stored in data_embeddings. The factor must be named and has to use the same names as used in data_embeddings.
data_folds: int. Number of cross-fold samples.
data_val_size: double between 0 and 1. Proportion of cases of each class which should be used for the validation sample during the estimation of the model. The remaining cases are part of the training data.
use_sc: bool. TRUE if the estimation should integrate synthetic cases, FALSE if not.
sc_method: vector containing the method for generating synthetic cases. Possible values are sc_method = "adas", sc_method = "smote", and sc_method = "dbsmote".
sc_min_k: int. Minimal number of k used for creating synthetic units.
sc_max_k: int. Maximal number of k used for creating synthetic units.
use_pl: bool. TRUE if the estimation should integrate pseudo-labeling, FALSE if not.
pl_max_steps: int. Maximum number of steps during pseudo-labeling.
pl_max: double between 0 and 1. Maximal level of confidence for considering a case for pseudo-labeling.
pl_anchor: double between 0 and 1. Reference point for sorting the new cases of every label. See notes for more details.
pl_min: double between 0 and 1. Minimal level of confidence for considering a case for pseudo-labeling.
sustain_track: bool. If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.
sustain_iso_code: string. ISO code (Alpha-3 code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
sustain_region: string. Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html.
sustain_interval: int. Interval in seconds for measuring power usage.
epochs: int. Number of training epochs.
batch_size: int. Size of the batches for training.
Ns: int. Number of cases for every class in the sample.
Nq: int. Number of cases for every class in the query.
loss_alpha: double between 0 and 1. Indicates how strongly the loss should focus on pulling cases to their corresponding prototypes as opposed to pushing cases away from other prototypes. The higher the value, the more the loss concentrates on pulling cases to their corresponding prototypes.
loss_margin: double greater than 0. Minimal distance of every case from the prototypes of other classes.
sampling_separate: bool. If TRUE, the cases for every class are divided into a data set for the sample and a data set for the query; these are never mixed. If FALSE, sample and query cases are drawn from the same data pool; that is, a case can be part of the sample in one epoch and part of the query in another. It is ensured that a case is never part of the sample and the query at the same time and that every case occurs only once during a training step.
sampling_shuffle: bool. If TRUE, cases are randomly drawn from the data during every step. If FALSE, the cases are not shuffled.
dir_checkpoint: string. Path to the directory where the checkpoint during training should be saved. If the directory does not exist, it is created.
trace: bool. TRUE if information about the estimation phase should be printed to the console.
ml_trace: int. ml_trace = 0 does not print any information about the training process from pytorch to the console.
log_dir: string. Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL.
log_write_interval: int. Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL.
n_cores: int. Number of cores which should be used during the calculation of synthetic cases. Only relevant if use_sc = TRUE.
balance_class_weights: bool. If TRUE, class weights are generated based on the frequencies of the training data with the method 'Inverse Class Frequency'. If FALSE, each class has the weight 1.
balance_sequence_length: bool. If TRUE, sample weights are generated for the length of sequences based on the frequencies of the training data with the method 'Inverse Class Frequency'. If FALSE, each sequence length has the weight 1.
sc_max_k: All values from sc_min_k up to sc_max_k are successively used. If the value of sc_max_k is too high, it is reduced to a number that allows the calculation of synthetic units.
pl_anchor: With the help of this value, the new cases are sorted. For this aim, the distance from the anchor is calculated and all cases are arranged in ascending order. For example, with pl_anchor = 1 the cases whose estimated certainty is closest to 1 are considered first.
The function does not return a value. It changes the object into a trained classifier.
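Example (a minimal, untested sketch): training the configured classifier. The objects embeddings and targets are hypothetical; targets must be a named factor whose names match the case names in embeddings, with NA marking unlabeled cases.

classifier$train(
  data_embeddings = embeddings,
  data_targets = targets,           # named factor; NA marks unlabeled cases
  data_folds = 5,
  use_sc = TRUE,                    # add synthetic cases for minority classes
  sc_method = "dbsmote",
  use_pl = TRUE,                    # pseudo-labeling for unlabeled cases
  epochs = 40,
  batch_size = 35,
  Ns = 5,                           # cases per class in the sample
  Nq = 3,                           # cases per class in the query
  loss_alpha = 0.5,
  loss_margin = 0.5,
  sustain_track = TRUE,
  sustain_iso_code = "DEU",         # must be set when sustain_track = TRUE
  dir_checkpoint = "checkpoints",   # created if it does not exist
  trace = TRUE
)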
embed(): Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of texts created by an object of class TextEmbeddingModel. These embeddings represent documents according to their similarity to specific classes.
TEClassifierProtoNet$embed(embeddings_q = NULL, batch_size = 32)

embeddings_q: Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
batch_size: int. Batch size.
Returns a list containing the following elements:
embeddings_q: embeddings for the cases (query sample).
embeddings_prototypes: embeddings of the prototypes which were learned during training. They represent the centers of the different classes.
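Example (a minimal, untested sketch): embedding new cases into the classification space. The object new_embeddings is hypothetical and must have been created with the same TextEmbeddingModel as the training data.

embedded <- classifier$embed(embeddings_q = new_embeddings, batch_size = 32)
embedded$embeddings_q            # embeddings of the submitted cases
embedded$embeddings_prototypes   # learned prototypes (class centers)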
plot_embeddings(): Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
TEClassifierProtoNet$plot_embeddings(
embeddings_q,
classes_q = NULL,
batch_size = 12,
alpha = 0.5,
size_points = 3,
size_points_prototypes = 8,
inc_unlabeled = TRUE
)

embeddings_q: Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
classes_q: Named factor containing the true classes for every case. Please note that the names must match the names/ids in embeddings_q.
batch_size: int. Batch size.
alpha: float. Value indicating how transparent the points should be (important if many points overlap). Does not apply to points representing prototypes.
size_points: int. Size of the points excluding the points for prototypes.
size_points_prototypes: int. Size of points representing prototypes.
inc_unlabeled: bool. If TRUE, the plot includes unlabeled cases as data points.
Returns a plot of class ggplot visualizing the embeddings.
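Example (a minimal, untested sketch): visualizing the classification space. The objects new_embeddings and targets are hypothetical; the names of targets must match the names/ids in new_embeddings. Since the return value is a ggplot object, further ggplot2 layers can be added to it.

plot <- classifier$plot_embeddings(
  embeddings_q = new_embeddings,
  classes_q = targets,
  alpha = 0.5,
  size_points = 3,
  size_points_prototypes = 8,
  inc_unlabeled = TRUE            # also show unlabeled cases
)
plot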
clone(): The objects of this class are cloneable with this method.
TEClassifierProtoNet$clone(deep = FALSE)

deep: Whether to make a deep clone.
Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning. https://doi.org/10.48550/arXiv.1703.05175
Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H. Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing. https://doi.org/10.1007/978-3-030-16145-3_24
Other Classification:
TEClassifierRegular