aifeducation (version 1.1.2)

BaseModelCore: Abstract class for all BaseModels

Description

This class contains all methods shared by all BaseModels.

Arguments

Value

Returns a new object of this class.

Super classes

aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> BaseModelCore

Public fields

Tokenizer

('TokenizerBase')
Objects of class TokenizerBase.

Methods

Method create_from_hf()

Creates a BaseModel from a pretrained model.

Usage

BaseModelCore$create_from_hf(model_dir = NULL, tokenizer_dir = NULL)

Arguments

model_dir

tokenizer_dir

string Path to the directory where the tokenizer is saved. Allowed values: any

Returns

Returns a new object of this class.


Method train()

Trains a BaseModel.

Usage

BaseModelCore$train(
  text_dataset,
  p_mask = 0.15,
  whole_word = TRUE,
  val_size = 0.1,
  n_epoch = 1L,
  batch_size = 12L,
  max_sequence_length = 250L,
  full_sequences_only = FALSE,
  min_seq_len = 50L,
  learning_rate = 0.003,
  sustain_track = FALSE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE,
  pytorch_trace = 1L,
  log_dir = NULL,
  log_write_interval = 2L
)

Arguments

text_dataset

p_mask

whole_word

val_size

n_epoch

batch_size

max_sequence_length

full_sequences_only

min_seq_len

learning_rate

sustain_track

sustain_iso_code

sustain_region

sustain_interval

sustain_log_level

trace

pytorch_trace

log_dir

log_write_interval

Returns

Returns nothing.


Method count_parameter()

Method for counting the trainable parameters of a model.

Usage

BaseModelCore$count_parameter()

Returns

Returns the number of trainable parameters of the model.


Method plot_training_history()

Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.

Usage

BaseModelCore$plot_training_history(
  y_min = NULL,
  y_max = NULL,
  text_size = 10L
)

Arguments

y_min

y_max

text_size

Returns

Returns a plot of class ggplot visualizing the training process.


Method get_special_tokens()

Method for retrieving the special tokens of the model.

Usage

BaseModelCore$get_special_tokens()

Returns

Returns a matrix containing the special tokens in the rows and their type, token, and id in the columns.


Method get_tokenizer_statistics()

Tokenizer statistics

Usage

BaseModelCore$get_tokenizer_statistics()

Returns

Returns a data.frame containing the tokenizer's statistics.


Method fill_mask()

Method for predicting the tokens hidden behind mask tokens.

Usage

BaseModelCore$fill_mask(masked_text, n_solutions = 5L)

Arguments

masked_text

n_solutions

Returns

Returns a list containing a data.frame for every mask. The data.frame contains the solutions in the rows and reports the score, token id, and token string in the columns.
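The return structure described above (one data.frame per mask, one row per candidate solution) can be post-processed with ordinary list tools. The following is a minimal sketch that runs on mock data only: the helper best_per_mask(), the column names score, token_id, and token_str, and all values are assumptions made for illustration. The actual column names are not stated on this page, so inspect names(result[[1]]) on real output first.

```r
# Hypothetical helper: keep the highest-scoring solution for each mask.
# `score_col` is an assumed column name, not taken from the package.
best_per_mask <- function(result, score_col = "score") {
  lapply(result, function(df) df[which.max(df[[score_col]]), , drop = FALSE])
}

# Mock output for a text with two masks (invented values and column names).
mock_result <- list(
  data.frame(score = c(0.61, 0.22), token_id = c(101L, 57L),
             token_str = c("cat", "dog")),
  data.frame(score = c(0.10, 0.75), token_id = c(12L, 88L),
             token_str = c("ran", "sat"))
)

best <- best_per_mask(mock_result)
```

Selecting by maximum score instead of taking the first row avoids relying on the rows being sorted, which this page does not promise.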


Method save()

Method for saving a model on disk.

Usage

BaseModelCore$save(dir_path, folder_name)

Arguments

dir_path

Path to the directory where to save the object.

folder_name

string Name of the folder where the model should be saved. Allowed values: any

Returns

Returns nothing. It is used to save an object to disk.
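The train() and save() methods documented above are typically called one after the other. The following is a minimal sketch of that sequence, assuming `model` is an already created object of a concrete class that inherits these methods and `text_dataset` has been prepared by the caller. The helper name train_and_save() is hypothetical and not part of the package. The argument values simply repeat the defaults shown in the Usage block of train().

```r
# Hypothetical helper, not part of aifeducation.
# `model` must expose the train() and save() methods documented on this page.
train_and_save <- function(model, text_dataset, dir_path, folder_name) {
  model$train(
    text_dataset = text_dataset,
    p_mask = 0.15,              # documented default
    whole_word = TRUE,          # documented default
    val_size = 0.1,             # documented default
    n_epoch = 1L,
    batch_size = 12L,
    max_sequence_length = 250L,
    learning_rate = 0.003,
    sustain_track = FALSE,
    trace = TRUE
  )
  # Write the trained model to <dir_path>/<folder_name>.
  model$save(dir_path = dir_path, folder_name = folder_name)
  invisible(model)
}
```

Because both methods return nothing, the helper returns the model invisibly so that the caller can keep working with it.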


Method load_from_disk()

Loads an object from disk and updates the object to the current version of the package.

Usage

BaseModelCore$load_from_disk(dir_path)

Arguments

dir_path

Path where the object is stored.

Returns

Returns nothing. It loads an object from disk.


Method get_model()

Get 'PyTorch' model

Usage

BaseModelCore$get_model()

Returns

Returns the underlying 'PyTorch' model.


Method get_model_type()

Type of the underlying model.

Usage

BaseModelCore$get_model_type()

Returns

Returns a string describing the model's architecture.


Method get_final_size()

Size of the final layer.

Usage

BaseModelCore$get_final_size()

Returns

Returns an int describing the number of dimensions of the last hidden layer.


Method get_flops_estimates()

Flop estimates

Usage

BaseModelCore$get_flops_estimates()

Returns

Returns a data.frame containing statistics about the flops.


Method set_publication_info()

Method for setting the bibliographic information of the model.

Usage

BaseModelCore$set_publication_info(type, authors, citation, url = NULL)

Arguments

type

string Type of information to change or add. Possible values are developer and modifier.

authors

List of people.

citation

string Citation in free text.

url

string Corresponding URL if applicable.

Returns

Function does not return a value. It is used to set the private members for publication information of the model.


Method estimate_sustainability_inference_fill_mask()

Calculates the energy consumption for inference of the given task.

Usage

BaseModelCore$estimate_sustainability_inference_fill_mask(
  text_dataset = NULL,
  n = NULL,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE
)

Arguments

text_dataset

n

sustain_iso_code

sustain_region

sustain_interval

sustain_log_level

trace

Returns

Returns nothing. The method saves the statistics internally. They can be accessed with the method get_sustainability_data("inference").


Method calc_flops_architecture_based()

Calculates FLOPs based on the model's architecture.

Usage

BaseModelCore$calc_flops_architecture_based(batch_size, n_batches, n_epochs)

Arguments

batch_size

n_batches

n_epochs

Returns

Returns a data.frame storing the estimates.


Method clone()

The objects of this class are cloneable with this method.

Usage

BaseModelCore$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
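The difference between a shallow and a deep clone matters for fields that hold other R6 objects, such as the Tokenizer field. The following sketch shows the general R6 behaviour with two toy stand-in classes. ToyTokenizer and ToyModel are invented for illustration and are not part of aifeducation. Whether BaseModelCore customises deep cloning, and how its Python-backed model is treated, is not stated on this page.

```r
library(R6)

# Toy stand-ins (hypothetical), used only to show R6 clone semantics.
ToyTokenizer <- R6Class("ToyTokenizer",
  public = list(vocab_size = 100L)
)
ToyModel <- R6Class("ToyModel",
  public = list(
    Tokenizer = NULL,
    initialize = function() {
      self$Tokenizer <- ToyTokenizer$new()
    }
  )
)

original <- ToyModel$new()
shallow <- original$clone()           # deep = FALSE: R6 fields are shared
deep <- original$clone(deep = TRUE)   # R6 fields are cloned as well

# Change the tokenizer of the original after cloning.
original$Tokenizer$vocab_size <- 500L

shallow$Tokenizer$vocab_size  # 500: same tokenizer object as the original
deep$Tokenizer$vocab_size     # 100: independent copy
```

With deep = FALSE, changes made through one object's nested R6 field are visible through the other object, so deep = TRUE is the safer choice when the copy should be independent.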

See Also

Other R6 Classes for Developers: AIFEBaseModel, AIFEMaster, ClassifiersBasedOnTextEmbeddings, DataManagerClassifier, LargeDataSetBase, ModelsBasedOnTextEmbeddings, TEClassifiersBasedOnProtoNet, TEClassifiersBasedOnRegular, TokenizerBase