Represents models based on MPNet.
Returns a new object of this class.
aifeducation::AIFEMaster
-> aifeducation::AIFEBaseModel
-> aifeducation::BaseModelCore
-> BaseModelMPNet
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
aifeducation::BaseModelCore$calc_flops_architecture_based()
aifeducation::BaseModelCore$count_parameter()
aifeducation::BaseModelCore$create_from_hf()
aifeducation::BaseModelCore$estimate_sustainability_inference_fill_mask()
aifeducation::BaseModelCore$fill_mask()
aifeducation::BaseModelCore$get_final_size()
aifeducation::BaseModelCore$get_flops_estimates()
aifeducation::BaseModelCore$get_model()
aifeducation::BaseModelCore$get_model_type()
aifeducation::BaseModelCore$get_special_tokens()
aifeducation::BaseModelCore$get_tokenizer_statistics()
aifeducation::BaseModelCore$load_from_disk()
aifeducation::BaseModelCore$plot_training_history()
aifeducation::BaseModelCore$save()
aifeducation::BaseModelCore$set_publication_info()
configure()
Configures a new object of this class.
BaseModelMPNet$configure(
  tokenizer,
  max_position_embeddings = 512L,
  hidden_size = 768L,
  num_hidden_layers = 12L,
  num_attention_heads = 12L,
  intermediate_size = 3072L,
  hidden_act = "GELU",
  hidden_dropout_prob = 0.1,
  attention_probs_dropout_prob = 0.1
)
tokenizer
TokenizerBase
Tokenizer for the model.
max_position_embeddings
int
Maximum number of position embeddings. This parameter also determines the maximum length of a sequence that can be processed with the model. Allowed values: 10 <= x <= 4048
hidden_size
int
Number of neurons in each layer. This parameter determines the dimensionality of the resulting text embedding. Allowed values: 1 <= x <= 2048
num_hidden_layers
int
Number of hidden layers. Allowed values: 1 <= x
num_attention_heads
int
Number of attention heads for each self-attention layer. Only relevant if attention_type='multihead'. Allowed values: 0 <= x
intermediate_size
int
Size of the intermediate (feed-forward) projection layer within each transformer encoder layer. Allowed values: 1 <= x
hidden_act
string
Name of the activation function. Allowed values: 'GELU', 'relu', 'silu', 'gelu_new'
hidden_dropout_prob
double
Ratio of dropout. Allowed values: 0 <= x <= 0.6
attention_probs_dropout_prob
double
Ratio of dropout for attention probabilities. Allowed values: 0 <= x <= 0.6
Does not return anything.
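The allowed values listed above can be checked before calling configure(). The following base R sketch collects every violation in one pass. `check_mpnet_config()` is a hypothetical helper, not part of aifeducation, and the last check (hidden_size being a multiple of num_attention_heads) is an assumption based on how multi-head attention usually splits the hidden state across heads; it is not one of the documented allowed values.

```r
# Hypothetical helper, not part of aifeducation: returns the messages of all
# checks that fail for a proposed BaseModelMPNet configuration.
check_mpnet_config <- function(max_position_embeddings = 512L,
                               hidden_size = 768L,
                               num_hidden_layers = 12L,
                               num_attention_heads = 12L,
                               intermediate_size = 3072L,
                               hidden_act = "GELU",
                               hidden_dropout_prob = 0.1,
                               attention_probs_dropout_prob = 0.1) {
  # TRUE if scalar x lies in the closed interval [lo, hi]
  in_range <- function(x, lo, hi = Inf) x >= lo && x <= hi

  # Each element is TRUE when the check passes; its name is the error message.
  checks <- c(
    "max_position_embeddings must be in [10, 4048]" =
      in_range(max_position_embeddings, 10, 4048),
    "hidden_size must be in [1, 2048]" = in_range(hidden_size, 1, 2048),
    "num_hidden_layers must be >= 1" = in_range(num_hidden_layers, 1),
    "num_attention_heads must be >= 0" = in_range(num_attention_heads, 0),
    "intermediate_size must be >= 1" = in_range(intermediate_size, 1),
    "hidden_act must be one of GELU, relu, silu, gelu_new" =
      hidden_act %in% c("GELU", "relu", "silu", "gelu_new"),
    "hidden_dropout_prob must be in [0, 0.6]" =
      in_range(hidden_dropout_prob, 0, 0.6),
    "attention_probs_dropout_prob must be in [0, 0.6]" =
      in_range(attention_probs_dropout_prob, 0, 0.6),
    # Assumption (not a documented rule): heads split hidden_size evenly.
    # The zero guard avoids computing x %% 0, which yields NaN in R.
    "hidden_size should be a multiple of num_attention_heads" =
      num_attention_heads == 0 || hidden_size %% num_attention_heads == 0
  )
  names(checks)[!checks]
}

check_mpnet_config()                   # character(0): the defaults pass
check_mpnet_config(hidden_size = 100L) # 100 is not a multiple of 12
```

Returning all messages at once, rather than stopping at the first failure, makes it easier to fix a configuration in a single round.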
train()
Trains a BaseModel.
BaseModelMPNet$train(
  text_dataset,
  p_mask = 0.15,
  p_perm = 0.15,
  whole_word = TRUE,
  val_size = 0.1,
  n_epoch = 1L,
  batch_size = 12L,
  max_sequence_length = 250L,
  full_sequences_only = FALSE,
  min_seq_len = 50L,
  learning_rate = 0.003,
  sustain_track = FALSE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE,
  pytorch_trace = 1L,
  log_dir = NULL,
  log_write_interval = 2L
)
text_dataset
p_mask
p_perm
whole_word
val_size
n_epoch
batch_size
max_sequence_length
full_sequences_only
min_seq_len
learning_rate
sustain_track
sustain_iso_code
sustain_region
sustain_interval
sustain_log_level
trace
pytorch_trace
log_dir
log_write_interval
Does not return anything.
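The argument descriptions for train() are not part of this section. As a purely hypothetical illustration of how max_sequence_length, full_sequences_only and min_seq_len might interact when texts are cut into training sequences, the base R sketch below splits a vector of token ids into chunks. `chunk_tokens()` is an invented helper, and its filtering rules are assumptions, not aifeducation's documented behaviour.

```r
# Hypothetical illustration only, not aifeducation's implementation:
# split token ids into consecutive chunks of at most max_sequence_length,
# then filter the chunks under assumed rules.
chunk_tokens <- function(token_ids,
                         max_sequence_length = 250L,
                         full_sequences_only = FALSE,
                         min_seq_len = 50L) {
  n <- length(token_ids)
  if (n == 0L) return(list())

  # Start positions of consecutive, non-overlapping chunks
  starts <- seq(1L, n, by = max_sequence_length)
  chunks <- lapply(starts, function(s) {
    token_ids[s:min(s + max_sequence_length - 1L, n)]
  })
  lens <- vapply(chunks, length, integer(1))

  # Assumed rules: keep only complete chunks if full_sequences_only is TRUE,
  # otherwise drop chunks shorter than min_seq_len.
  keep <- if (full_sequences_only) {
    lens == max_sequence_length
  } else {
    lens >= min_seq_len
  }
  chunks[keep]
}

length(chunk_tokens(1:600))                              # chunks of 250, 250, 100
length(chunk_tokens(1:600, full_sequences_only = TRUE))  # the partial chunk is dropped
```

Under these assumptions, a 520-token text would yield two chunks, because its 20-token remainder falls below min_seq_len.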
clone()
The objects of this class are cloneable with this method.
BaseModelMPNet$clone(deep = FALSE)
deep
Whether to make a deep clone.
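clone() is the standard method that the R6 package generates for cloneable classes. The toy classes below (invented for illustration, not part of aifeducation) show the difference the deep argument makes in plain R6: a shallow clone shares fields that are themselves R6 objects, while a deep clone copies them. By default, R6 deep-copies only fields that are R6 objects. This section does not state whether BaseModelMPNet adds any further deep-copy logic, for example for an underlying Python model.

```r
library(R6)

# Toy classes for illustration only; they are not part of aifeducation.
Counter <- R6Class("Counter", public = list(
  n = 0,
  add = function() {
    self$n <- self$n + 1
    invisible(self)
  }
))
Holder <- R6Class("Holder", public = list(
  counter = NULL,
  # Create the reference field in initialize so each Holder gets its own Counter
  initialize = function() {
    self$counter <- Counter$new()
  }
))

original <- Holder$new()
shallow  <- original$clone()            # deep = FALSE (the default)
deep     <- original$clone(deep = TRUE) # nested R6 fields are cloned too

original$counter$add()
shallow$counter$n  # 1: the shallow copy shares the same Counter object
deep$counter$n     # 0: the deep copy received its own Counter
```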
Song, K., Tan, X., Qin, T., Lu, J. & Liu, T.-Y. (2020). MPNet: Masked and Permuted Pre-training for Language Understanding. doi:10.48550/arXiv.2004.09297
Other Base Model:
BaseModelBert, BaseModelDebertaV2, BaseModelFunnel, BaseModelModernBert, BaseModelRoberta