# # Example, kept fully commented out (nothing below executes as written):
# # fit an sbrl model on a two-class subset of iris, convert it with
# # sbrlModel2arcCBARuleModel(), postprocess it with qcba(), and report accuracy,
# # rule count and average rule length before and after postprocessing.
# # rCBA and sbrl are optional; the example body runs only when both are installed.
# if (! requireNamespace("rCBA", quietly = TRUE)) {
# message("Please install rCBA to allow for sbrl model conversion")
# # NOTE(review): return() is only valid inside a function - confirm the context
# # this example is meant to run in
# return()
# } else if (! requireNamespace("sbrl", quietly = TRUE)) {
# message("Please install sbrl to allow for postprocessing of sbrl models")
#} else
#{
# library(sbrl)
# library(rCBA)
# # sbrl handles only binary problems; iris has 3 target classes, so remove one class
# # fix the RNG seed so that the row shuffle below is reproducible
# set.seed(111)
# allData <- datasets::iris[sample(nrow(datasets::iris)),]
# classToExclude<-"versicolor"
# allData <- allData[allData$Species!=classToExclude, ]
# # drop the now-unused factor level of the excluded class (versicolor)
# allData$Species <-allData$Species [, drop=TRUE]
# # first 50 shuffled rows form the training fold, the remaining rows the test fold
# trainFold <- allData[1:50,]
# testFold <- allData[51:nrow(allData),]
# # rename the class column (Species) to the fixed name held in sbrlFixedLabel
# sbrlFixedLabel<-"label"
# origLabel<-"Species"
# orignames<-colnames(trainFold)
# orignames[which(orignames == origLabel)]<-sbrlFixedLabel
# colnames(trainFold)<-orignames
# colnames(testFold)<-orignames
# # to recode the label to binary values:
# # first create a dict mapping the original distinct class values to 0 and 1
# # (range(0,1) evaluates to c(0, 1); this assumes exactly two class values
# # are present in the training fold)
# origval<-levels(as.factor(trainFold$label))
# newval<-range(0,1)
# dict<-data.frame(origval,newval)
# # then apply the dict to the train and test folds
# trainFold$label<-dict[match(trainFold$label, dict$origval), 2]
# testFold$label<-dict[match(testFold$label, dict$origval), 2]
# # discretize the training data and keep the cutpoints for reuse on the test data
# trainFoldDiscTemp <- discrNumeric(trainFold, sbrlFixedLabel)
# trainFoldDiscCutpoints <- trainFoldDiscTemp$cutp
# trainFoldDisc <- as.data.frame(lapply(trainFoldDiscTemp$Disc.data, as.factor))
# # discretize the test data with the cutpoints obtained from the training data
# testFoldDisc <- applyCuts(testFold, trainFoldDiscCutpoints, infinite_bounds=TRUE, labels=TRUE)
# # learn the sbrl model
# sbrl_model <- sbrl(trainFoldDisc, iters=30000, pos_sign="0",
# neg_sign="1", rule_minlen=1, rule_maxlen=10,
# minsupport_pos=0.10, minsupport_neg=0.10,
# lambda=10.0, eta=1.0, alpha=c(1,1), nchain=10)
# # apply the sbrl model to the test fold
# yhat <- predict(sbrl_model, testFoldDisc)
# # threshold at 0.5 to obtain 0/1 predictions
# # NOTE(review): assumes yhat$V1 > 0.5 corresponds to label value 1 - confirm
# # against the documentation of sbrl's predict method
# yvals<- as.integer(yhat$V1>0.5)
# # accuracy = share of test rows whose thresholded prediction equals the label
# sbrl_acc<-mean(as.integer(yvals == testFoldDisc$label))
# message("SBRL RESULT")
# sbrl_model
# # convert the sbrl model, passing the training cutpoints, the training fold
# # and the class column name
# rm_sbrl<-sbrlModel2arcCBARuleModel(sbrl_model,trainFoldDiscCutpoints,trainFold,sbrlFixedLabel)
# # average rule length here = sum over the lhs data slot divided by the number of rules
# message(paste("sbrl acc=",sbrl_acc,"sbrl rule count=",nrow(sbrl_model$rs), "avg rule length",
# sum(rm_sbrl@rules@lhs@data)/length(rm_sbrl@rules)))
# # postprocess the converted model with qcba using the training fold, then
# # evaluate the result on the test fold
# rmQCBA_sbrl <- qcba(cbaRuleModel=rm_sbrl,datadf=trainFold)
# prediction <- predict(rmQCBA_sbrl,testFold)
# acc_qcba_sbrl <- CBARuleModelAccuracy(prediction, testFold[[rmQCBA_sbrl@classAtt]])
# # stringr is optional; without it the average rule length is reported as NA
# if (! requireNamespace("stringr", quietly = TRUE)) {
# message("Please install stringr to compute average rule length for QCBA")
# avg_rule_length <- NA
# } else
# {
# library(stringr)
# # count the commas in the first column of the rule table; a rule with k
# # conditions contributes k-1 commas, so one is added back per rule below
# avg_rule_length <- (sum(unlist(lapply(rmQCBA_sbrl@rules[1],str_count,pattern=",")))+
# # add one per rule except the last, which is assumed to have an empty antecedent
# nrow(rmQCBA_sbrl@rules)-1)/nrow(rmQCBA_sbrl@rules)
# }
# message("QCBA RESULT")
# rmQCBA_sbrl@rules
# message(paste("QCBA after SBRL acc=",acc_qcba_sbrl,"rule count=",
# rmQCBA_sbrl@ruleCount, "avg rule length", avg_rule_length))
# unlink("tdata_R.label") # delete temp files created by SBRL
# unlink("tdata_R.out")
# }
# Run the code above in your browser using DataLab