# Example: feature filtering with PipeOpFilter (mlr3pipelines).
# The whole example only runs when the optional packages it needs are
# installed. `po()`, `%>>%` and `GraphLearner` come from mlr3pipelines, so
# that package is checked and attached explicitly to make this script
# runnable on its own.
if (requireNamespace("mlr3filters") && requireNamespace("rpart") &&
  requireNamespace("mlr3pipelines")) {
  library("mlr3")
  library("mlr3filters")
  library("mlr3pipelines")
  # Restrict data.table to a single thread.
  data.table::setDTthreads(1)

  # Set up a PipeOpFilter that keeps the 5 most important features of the
  # spam task w.r.t. their AUC. Named `po_filter` (not `po`) so the `po()`
  # constructor used below is not shadowed by a PipeOp object.
  task = tsk("spam")
  filter = flt("auc")
  po_filter = po("filter", filter = filter)
  # Values inside an `if` body are not auto-printed, so print explicitly.
  print(po_filter$param_set)
  po_filter$param_set$values$filter.nfeat = 5

  # Filter the task; `$train()` takes and returns a list of inputs/outputs.
  filtered_task = po_filter$train(list(task))[[1]]

  # Filtered task + extracted AUC scores (first 10).
  print(filtered_task$feature_names)
  print(head(po_filter$state$scores, 10))

  # Feature selection embedded in a holdout resampling:
  # keep 50% of the features based on their AUC score.
  task = tsk("spam")
  gr = po("filter", filter = flt("auc"), filter.frac = 0.5) %>>%
    po("learner", lrn("classif.rpart"))
  learner = GraphLearner$new(gr)
  rr = resample(task, learner, rsmp("holdout"), store_models = TRUE)
  # NOTE(review): assumes the filter PipeOp's id is "auc" (taken from the
  # filter), so its trained state is stored under `$model$auc` — confirm.
  print(rr$learners[[1]]$model$auc$scores)
}
# Run the code above in your browser using DataLab