Abstract class for managing the data and samples during the training of a classifier. DataManagerClassifier is used with all classifiers based on text embeddings.
Objects of this class ensure correct data management when training different types of classifiers. They are also used for data augmentation by creating synthetic cases with different techniques.
config
('list')
Field for storing configuration of the DataManagerClassifier.
state
('list')
Field for storing the current state of the DataManagerClassifier.
datasets
('list')
Field for storing the data sets used during training. All elements of the list are data sets of class datasets.arrow_dataset.Dataset. The following data sets are available:
data_labeled: all cases which have a label.
data_unlabeled: all cases which have no label.
data_labeled_synthetic: all synthetic cases with their corresponding labels.
data_labeled_pseudo: subset of data_unlabeled for which pseudo labels were estimated by a classifier.
name_idx
('named vector')
Field for storing the pairs of indexes and names of every case. The pairs for labeled and unlabeled data are stored separately.
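As a rough illustration of how case names could be split into separate labeled and unlabeled index/name pairs, the following base R sketch builds two named vectors. The case names, the labels, and the use of 1-based indexes are assumptions for illustration only; the internal representation of name_idx may differ.

```r
# Hypothetical case names as they might appear in the embeddings
case_names <- c("doc_a", "doc_b", "doc_c", "doc_d", "doc_e")

# Named factor of labels; "doc_c" and "doc_e" have no label
targets <- factor(c("pos", "neg", "pos"))
names(targets) <- c("doc_a", "doc_b", "doc_d")

# A case counts as labeled if its name appears among the non-missing targets
is_labeled <- case_names %in% names(targets)[!is.na(targets)]
labeled_names <- case_names[is_labeled]
unlabeled_names <- case_names[!is_labeled]

# Named vectors: values are (assumed 1-based) indexes, names are case names
name_idx <- list(
  labeled = setNames(seq_along(labeled_names), labeled_names),
  unlabeled = setNames(seq_along(unlabeled_names), unlabeled_names)
)
```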
samples
('list')
Field for storing the assignment of every case to a train, validation, or test data set depending on the concrete fold. Only the indexes and not the names are stored. In addition, the list contains the assignment for the final training, which excludes a test data set. If the DataManagerClassifier uses i folds, the sample for the final training can be requested with i+1.
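The structure of this list can be sketched in base R: one entry per fold with train, validation, and test indexes, plus an entry at position folds+1 for the final training without a test set. This is a minimal sketch assuming a plain random split; make_samples is a hypothetical helper, and the package may, for example, stratify by class.

```r
# Hypothetical helper: assign case indexes to train/val/test per fold
make_samples <- function(n_cases, folds, val_size, seed = 42) {
  set.seed(seed)
  idx <- sample(seq_len(n_cases))           # shuffled case indexes
  fold_id <- rep_len(seq_len(folds), n_cases)
  samples <- vector("list", folds + 1)
  for (i in seq_len(folds)) {
    test <- sort(idx[fold_id == i])         # fold i is the test set
    rest <- idx[fold_id != i]
    n_val <- max(1, round(length(rest) * val_size))
    samples[[i]] <- list(
      train = sort(rest[-seq_len(n_val)]),
      val = sort(rest[seq_len(n_val)]),
      test = test
    )
  }
  # Final training (position folds + 1): no test set
  n_val <- max(1, round(n_cases * val_size))
  samples[[folds + 1]] <- list(
    train = sort(idx[-seq_len(n_val)]),
    val = sort(idx[seq_len(n_val)]),
    test = NULL
  )
  samples
}

s <- make_samples(n_cases = 20, folds = 4, val_size = 0.25)
```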
new()
Creating a new instance of this class.
DataManagerClassifier$new(
data_embeddings,
data_targets,
folds = 5,
val_size = 0.25,
pad_value = -100,
class_levels,
one_hot_encoding = TRUE,
add_matrix_map = TRUE,
sc_methods = "knnor",
sc_min_k = 1,
sc_max_k = 10,
trace = TRUE,
n_cores = auto_n_cores()
)
data_embeddings
EmbeddedText, LargeDataSetForTextEmbeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
data_targets
factor
containing the labels for the cases stored in data_embeddings. The factor must be named and has to use the same names as the embeddings.
folds
int
determining the number of cross-fold samples. Allowed values: 1 <= x
val_size
double
between 0 and 1, indicating the proportion of cases which should be
used for the validation sample during the estimation of the model.
The remaining cases are part of the training data. Allowed values: 0 < x < 1
pad_value
int
Value indicating padding. This value should not be in the range of regular values for computations. Thus, it is not recommended to change this value. Default is -100. Allowed values: x <= -100
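To illustrate why the padding value must lie outside the range of regular values, the sketch below marks unused time steps with -100 and recovers each case's actual sequence length by counting time steps that contain at least one non-padding value. The array layout (cases x times x features) is an assumption for illustration.

```r
pad_value <- -100

# Hypothetical embeddings: 2 cases x 3 time steps x 2 features, all padded
emb <- array(pad_value, dim = c(2, 3, 2))
emb[1, 1:3, ] <- 0.5   # case 1 uses all three time steps
emb[2, 1, ] <- 0.1     # case 2 uses only the first time step

# A time step counts as real if any of its features differs from pad_value
seq_length <- apply(emb, 1, function(case) sum(rowSums(case != pad_value) > 0))
```

If pad_value could occur as a regular embedding value, real time steps would be miscounted as padding, which is why values of -100 or lower are required.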
class_levels
vector
containing the levels (categories or classes) within the target data. Please
note that order matters. For ordinal data please ensure that the levels are sorted correctly with later levels
indicating a higher category/class. For nominal data the order does not matter.
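Because the order of class_levels determines the numeric code of each class, changing the order changes the codes. The sketch below assumes 0-based codes, which is a common convention for Python-based training frameworks but is not confirmed by this documentation.

```r
class_levels <- c("low", "medium", "high")
labels <- c("high", "low", "medium", "high")

# Codes follow the order of class_levels (assumed 0-based)
codes <- as.integer(factor(labels, levels = class_levels)) - 1L

# Reversing the level order yields different codes for the same labels
codes_reversed <- as.integer(factor(labels, levels = rev(class_levels))) - 1L
```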
one_hot_encoding
bool
If TRUE, all labels are converted to one-hot encoding.
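One-hot encoding represents each label as a row with a single 1 in the column of its class. A minimal base R sketch, assuming the columns follow the order of class_levels:

```r
class_levels <- c("low", "medium", "high")
labels <- factor(c("high", "low", "medium"), levels = class_levels)

# One row per case, one column per class level
one_hot <- matrix(0L, nrow = length(labels), ncol = length(class_levels),
                  dimnames = list(NULL, class_levels))

# Set a 1 at (case, class) using matrix indexing with a two-column index
one_hot[cbind(seq_along(labels), as.integer(labels))] <- 1L
```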
add_matrix_map
bool
If TRUE, all embeddings are transformed into a two-dimensional matrix. The number of rows equals the number of cases. The number of columns equals times*features.
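The transformation from a cases x times x features array into a cases x (times*features) matrix can be sketched with aperm and dim. The column ordering used here (all features of the first time step, then all features of the second, and so on) is an assumption matching a row-major reshape; the package's actual ordering may differ. to_matrix_map is a hypothetical name.

```r
# Hypothetical helper: flatten cases x times x features into a matrix
to_matrix_map <- function(emb) {
  d <- dim(emb)                    # c(cases, times, features)
  m <- aperm(emb, c(1, 3, 2))      # cases x features x times
  dim(m) <- c(d[1], d[2] * d[3])   # features vary fastest within each time step
  m
}

emb <- array(1:12, dim = c(2, 3, 2))
m <- to_matrix_map(emb)
```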
sc_methods
string
containing the method for generating synthetic cases. Allowed values: 'knnor'
sc_min_k
int
determining the minimal value of k used for creating synthetic cases. Allowed values: 1 <= x
sc_max_k
int
determining the maximal value of k used for creating synthetic cases. Allowed values: 1 <= x
trace
bool
TRUE if information about the estimation phase should be printed to the console.
n_cores
int
Number of cores which should be used during the calculation of synthetic cases. Only relevant if use_sc=TRUE. Allowed values: 1 <= x
Method returns an initialized object of class DataManagerClassifier.
get_config()
Method for requesting the configuration of the DataManagerClassifier.
DataManagerClassifier$get_config()
Returns a list storing the configuration of the DataManagerClassifier.
get_labeled_data()
Method for requesting the complete labeled data set.
DataManagerClassifier$get_labeled_data()
Returns an object of class datasets.arrow_dataset.Dataset containing all cases with labels.
get_unlabeled_data()
Method for requesting the complete unlabeled data set.
DataManagerClassifier$get_unlabeled_data()
Returns an object of class datasets.arrow_dataset.Dataset containing all cases without labels.
get_samples()
Method for requesting the assignments to train, validation, and test data sets for every fold and the final training.
DataManagerClassifier$get_samples()
Returns a list storing the assignments to a train, validation, and test data set for every fold. In the case of the sample for the final training, the test data set is always empty (NULL).
set_state()
Method for setting the current state of the DataManagerClassifier.
DataManagerClassifier$set_state(iteration, step = NULL)
iteration
int
determining the current iteration of the training. That is, iteration determines the fold to use for training, validation, and testing. If i is the number of folds, i+1 requests the sample for the final training. Alternatively, iteration can be set to the string "final" to request the sample for the final training.
step
int
determining the step for estimating and using pseudo labels during training. Only relevant if
training is requested with pseudo labels.
Method does not return anything. It is used for setting the internal state of the DataManagerClassifier.
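The accepted values for iteration can be summarized by a small resolver that maps an integer or the string "final" to a position in the list of samples. resolve_iteration is a hypothetical helper for illustration and is not part of the class interface.

```r
# Hypothetical helper: map `iteration` to a position in the samples list
resolve_iteration <- function(iteration, n_folds) {
  if (is.character(iteration)) {
    if (identical(iteration, "final")) return(n_folds + 1)
    stop("iteration must be an integer or \"final\"")
  }
  iteration <- as.integer(iteration)
  # Valid positions: folds 1..n_folds plus n_folds + 1 for the final training
  if (is.na(iteration) || iteration < 1L || iteration > n_folds + 1L) {
    stop("iteration out of range")
  }
  iteration
}
```

With five folds, resolve_iteration(6, 5) and resolve_iteration("final", 5) both point to the sample for the final training.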
get_n_folds()
Method for requesting the number of folds the DataManagerClassifier can use with the current data.
DataManagerClassifier$get_n_folds()
Returns the number of folds the DataManagerClassifier uses.
get_n_classes()
Method for requesting the number of classes.
DataManagerClassifier$get_n_classes()
Returns the number of classes.
get_statistics()
Method for requesting descriptive sample statistics.
DataManagerClassifier$get_statistics()
Returns a table describing the absolute frequencies of the labeled and unlabeled data. The rows contain the length of the sequences while the columns contain the labels.
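A table of this shape can be produced in base R with table(). In this sketch, unlabeled cases are represented by NA labels and appear as a separate NA column via useNA = "ifany"; this representation is an assumption, and the sequence lengths and labels are made-up example values.

```r
# Hypothetical sequence lengths and labels (NA = unlabeled case)
seq_length <- c(3, 1, 3, 2, 1)
labels <- factor(c("pos", "neg", NA, "pos", "neg"), levels = c("neg", "pos"))

# Rows: sequence length; columns: labels, plus an NA column for unlabeled cases
stats <- table(length = seq_length, label = labels, useNA = "ifany")
```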
contains_unlabeled_data()
Method for checking if the dataset contains cases without labels.
DataManagerClassifier$contains_unlabeled_data()
Returns TRUE if the dataset contains cases without labels. Returns FALSE if all cases have labels.
get_dataset()
Method for requesting a data set for training depending on the current state of the DataManagerClassifier.
DataManagerClassifier$get_dataset(
inc_labeled = TRUE,
inc_unlabeled = FALSE,
inc_synthetic = FALSE,
inc_pseudo_data = FALSE
)
inc_labeled
bool
If TRUE, the data set includes all cases which have labels.
inc_unlabeled
bool
If TRUE, the data set includes all cases which have no labels.
inc_synthetic
bool
If TRUE, the data set includes all synthetic cases with their corresponding labels.
inc_pseudo_data
bool
If TRUE, the data set includes all cases which have pseudo labels.
Returns an object of class datasets.arrow_dataset.Dataset containing the requested kind of data along with all requested transformations for training. Please note that this method returns a data set that is designed for training only. The corresponding validation data set is requested with get_val_dataset and the corresponding test data set with get_test_dataset.
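The effect of the inc_* flags can be pictured as selecting and stacking the corresponding parts. The sketch uses data frames as stand-ins for the Arrow data sets and ignores the fold-specific selection and transformations that the real method applies; build_training_set and the list element names are hypothetical.

```r
# Hypothetical helper: stack the parts selected by the inc_* flags
build_training_set <- function(data, inc_labeled = TRUE, inc_unlabeled = FALSE,
                               inc_synthetic = FALSE, inc_pseudo_data = FALSE) {
  selected <- c(labeled = inc_labeled, unlabeled = inc_unlabeled,
                synthetic = inc_synthetic, pseudo = inc_pseudo_data)
  parts <- data[names(selected)[selected]]
  parts <- Filter(Negate(is.null), parts)   # drop parts that do not exist
  do.call(rbind, unname(parts))
}

# Data frames standing in for the stored data sets (no pseudo data yet)
data <- list(
  labeled = data.frame(id = c("a", "b"), label = c("pos", "neg")),
  unlabeled = data.frame(id = "c", label = NA),
  synthetic = data.frame(id = "s1", label = "pos")
)
out <- build_training_set(data, inc_synthetic = TRUE)
```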
get_val_dataset()
Method for requesting a data set for validation depending on the current state of the DataManagerClassifier.
DataManagerClassifier$get_val_dataset()
Returns an object of class datasets.arrow_dataset.Dataset containing the requested kind of data along with all requested transformations for validation. The corresponding data set for training can be requested with get_dataset and the corresponding data set for testing with get_test_dataset.
get_test_dataset()
Method for requesting a data set for testing depending on the current state of the DataManagerClassifier.
DataManagerClassifier$get_test_dataset()
Returns an object of class datasets.arrow_dataset.Dataset containing the requested kind of data along with all requested transformations for testing. The corresponding data set for training can be requested with get_dataset and the corresponding data set for validation with get_val_dataset.
create_synthetic()
Method for generating synthetic data used during training. The process uses all labeled data belonging to the current state of the DataManagerClassifier.
DataManagerClassifier$create_synthetic(trace = TRUE, inc_pseudo_data = FALSE)
trace
bool
If TRUE, information on the process is printed to the console.
inc_pseudo_data
bool
If TRUE, data with pseudo labels are used in addition to the labeled data for generating synthetic cases.
This method does not return anything. It generates a new data set for synthetic cases, which is stored as an object of class datasets.arrow_dataset.Dataset in the field datasets$data_labeled_synthetic. Please note that a call of this method will override an existing data set in the corresponding field.
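The general idea of neighbour-based oversampling can be sketched with a simplified SMOTE-style interpolation: pick a case of one class, pick one of its k nearest neighbours of the same class, and place a new point at a random position on the line between them. This is NOT the KNNOR method named in sc_methods, which works differently in its details; make_synthetic is a hypothetical helper, and x is assumed to hold the matrix-mapped embeddings of a single class.

```r
# Simplified SMOTE-style oversampling for the cases of one class (not KNNOR)
make_synthetic <- function(x, n_new, k, seed = 1) {
  stopifnot(nrow(x) > k)
  set.seed(seed)
  d <- as.matrix(dist(x))                    # pairwise Euclidean distances
  out <- matrix(NA_real_, nrow = n_new, ncol = ncol(x))
  for (j in seq_len(n_new)) {
    i <- sample.int(nrow(x), 1)              # random seed case
    neighbours <- order(d[i, ])[2:(k + 1)]   # k nearest, skipping the case itself
    nb <- neighbours[sample.int(k, 1)]
    gap <- runif(1)                          # position between case and neighbour
    out[j, ] <- x[i, ] + gap * (x[nb, ] - x[i, ])
  }
  out
}

# Four cases of one class at the corners of the unit square
x <- matrix(c(0, 0, 1, 0, 0, 1, 1, 1), ncol = 2, byrow = TRUE)
synthetic <- make_synthetic(x, n_new = 10, k = 2)
```

In the spirit of sc_min_k and sc_max_k, such a function could be called once for every k from sc_min_k to sc_max_k; whether the package does exactly this is not stated here.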
add_replace_pseudo_data()
Method for adding data with pseudo labels generated by a classifier, replacing any existing data with pseudo labels.
DataManagerClassifier$add_replace_pseudo_data(inputs, labels)
inputs
array or matrix representing the input data.
labels
factor
containing the corresponding pseudo labels.
This method does not return anything. It generates a new data set for cases with pseudo labels, which is stored as an object of class datasets.arrow_dataset.Dataset in the field datasets$data_labeled_pseudo. Please note that a call of this method will override an existing data set in the corresponding field.
clone()
The objects of this class are cloneable with this method.
DataManagerClassifier$clone(deep = FALSE)
deep
Whether to make a deep clone.