# Example: learn an SBRL model on a two-class subset of iris, convert it to a
# CBA-style rule model and postprocess it with QCBA.
# rCBA and sbrl are optional dependencies; when either is missing, the example
# is skipped with an installation hint.
if (!requireNamespace("rCBA", quietly = TRUE)) {
  # No return() here: the if/else chain already skips the example, and
  # return() outside of a function is an error at top level.
  message("Please install rCBA to allow for sbrl model conversion")
} else if (!requireNamespace("sbrl", quietly = TRUE)) {
  message("Please install sbrl to allow for postprocessing of sbrl models")
} else {
  # This will run only outside a CRAN test, if the environment variable
  # NOT_CRAN is set to true. This environment variable is set by devtools.
  if (identical(Sys.getenv("NOT_CRAN"), "true")) {
    library(sbrl)
    library(rCBA)

    # Data preparation ----
    # sbrl handles only binary problems, iris has 3 target classes - remove
    # one class
    set.seed(getOption("qcba.seed"))
    allData <- datasets::iris[sample(nrow(datasets::iris)), ]
    classToExclude <- "versicolor"
    allData <- allData[allData$Species != classToExclude, ]
    # drop the removed level
    allData$Species <- allData$Species[, drop = TRUE]
    trainFold <- allData[1:50, ]
    testFold <- allData[51:nrow(allData), ]

    # Rename the target column to the fixed name used for the sbrl label
    sbrlFixedLabel <- "label"
    origLabel <- "Species"
    orignames <- colnames(trainFold)
    orignames[which(orignames == origLabel)] <- sbrlFixedLabel
    colnames(trainFold) <- orignames
    colnames(testFold) <- orignames

    # to recode label to binary values:
    # first create dict mapping from original distinct class values to 0,1
    origval <- levels(as.factor(trainFold$label))
    newval <- c(0, 1)
    dict <- data.frame(origval, newval)
    # then apply dict to train and test fold
    trainFold$label <- dict[match(trainFold$label, dict$origval), 2]
    testFold$label <- dict[match(testFold$label, dict$origval), 2]

    # Discretization ----
    # discretize training data
    trainFoldDiscTemp <- discrNumeric(trainFold, sbrlFixedLabel)
    trainFoldDiscCutpoints <- trainFoldDiscTemp$cutp
    trainFoldDisc <- as.data.frame(
      lapply(trainFoldDiscTemp$Disc.data, as.factor)
    )
    # discretize test data with the cutpoints learned on the training data
    testFoldDisc <- applyCuts(
      testFold, trainFoldDiscCutpoints,
      infinite_bounds = TRUE, labels = TRUE
    )
    # SBRL 1.4 crashes if features contain a space
    # even if these features are converted to factors,
    # to circumvent this, it is necessary to replace spaces
    trainFoldDisc <- as.data.frame(
      lapply(trainFoldDisc, function(x) gsub(" ", "", as.character(x)))
    )
    for (name in names(trainFoldDisc)) {
      trainFoldDisc[[name]] <- as.factor(trainFoldDisc[[name]])
    }

    # SBRL model ----
    # learn sbrl model, rule_minlen is increased to demonstrate the effect of
    # postprocessing
    sbrl_model <- sbrl(
      trainFoldDisc,
      iters = 20000, pos_sign = "0", neg_sign = "1",
      rule_minlen = 3, rule_maxlen = 5,
      minsupport_pos = 0.05, minsupport_neg = 0.05,
      lambda = 20.0, eta = 5.0, nchain = 25
    )
    # apply sbrl model on a test fold
    yhat <- predict(sbrl_model, testFoldDisc)
    # NOTE(review): assumes yhat$V1 > 0.5 corresponds to label 1 under the
    # pos_sign/neg_sign settings above - confirm against the sbrl documentation
    yvals <- as.integer(yhat$V1 > 0.5)
    sbrl_acc <- mean(as.integer(yvals == testFoldDisc$label))
    message("SBRL RESULT")
    message(sbrl_model)

    # Conversion and QCBA postprocessing ----
    rm_sbrl <- sbrlModel2arcCBARuleModel(
      sbrl_model, trainFoldDiscCutpoints, trainFold, sbrlFixedLabel
    )
    message(paste(
      "sbrl acc=", sbrl_acc, ", sbrl rule count=", nrow(sbrl_model$rs),
      ",\navg condition count (incl. default rule)",
      sum(rm_sbrl@rules@lhs@data) / length(rm_sbrl@rules)
    ))
    rmQCBA_sbrl <- qcba(cbaRuleModel = rm_sbrl, datadf = trainFold)
    prediction <- predict(rmQCBA_sbrl, testFold)
    acc_qcba_sbrl <- CBARuleModelAccuracy(
      prediction, testFold[[rmQCBA_sbrl@classAtt]]
    )
    # Average over all rules: total number of conditions divided by the
    # number of rules (a plain division would give one value per rule)
    avg_rule_length <- sum(rmQCBA_sbrl@rules$condition_count) /
      nrow(rmQCBA_sbrl@rules)
    message("RESULT of QCBA postprocessing of SBRL")
    message(rmQCBA_sbrl@rules)
    message(paste(
      "QCBA after SBRL acc=", acc_qcba_sbrl, ", rule count=",
      rmQCBA_sbrl@ruleCount, ", avg condition count (incl. default rule)",
      avg_rule_length
    ))

    # Cleanup ----
    unlink("tdata_R.label") # delete temp files created by SBRL
    unlink("tdata_R.out")
  }
}
# Run the code above in your browser using DataLab